fugue/core/
distribution.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/core/distribution.md"))]
2use rand::{Rng, RngCore};
3use rand_distr::{
4    Beta as RDBeta, Binomial as RDBinomial, Distribution as RandDistr, Exp as RDExp,
5    Gamma as RDGamma, LogNormal as RDLogNormal, Normal as RDNormal, Poisson as RDPoisson,
6};
/// Natural logarithm of a probability (or of a probability density).
///
/// This is a plain alias for `f64`: `f64::NEG_INFINITY` stands for probability zero,
/// and every finite value is the natural log of the probability it encodes.
pub type LogF64 = f64;
12
13/// Generic interface for type-safe probability distributions.
14/// All distributions implement `Distribution<T>` where `T` is the natural return type.
15/// Example:
16///
17/// ```rust
18/// # use fugue::*;
19/// # use rand::thread_rng;
20///
21/// let mut rng = thread_rng();
22///
23/// // Type-safe sampling
24/// let coin = Bernoulli::new(0.5).unwrap();
25/// let flip: bool = coin.sample(&mut rng);  // Natural boolean
26/// let prob = coin.log_prob(&flip);
27///
28/// // Safe indexing
29/// let choice = Categorical::uniform(3).unwrap();
30/// let idx: usize = choice.sample(&mut rng);  // Safe for arrays
31/// let choice_prob = choice.log_prob(&idx);
32///
33/// // Natural counting
34/// let events = Poisson::new(3.0).unwrap();
35/// let count: u64 = events.sample(&mut rng);  // Natural count type
36/// let count_prob = events.log_prob(&count);
37/// ```
38pub trait Distribution<T>: Send + Sync {
39    /// Generate a random sample (with its natural type), `T`, from the distribution, using the provided random number generator, `rng`.
40    ///
41    /// Example:
42    /// ```rust
43    /// # use fugue::*;
44    /// # use rand::thread_rng;
45    ///
46    /// let mut rng = thread_rng();
47    ///
48    /// // Sample different distribution types
49    /// let normal_sample: f64 = Normal::new(0.0, 1.0).unwrap().sample(&mut rng);
50    /// let coin_flip: bool = Bernoulli::new(0.5).unwrap().sample(&mut rng);
51    /// let event_count: u64 = Poisson::new(3.0).unwrap().sample(&mut rng);
52    /// let category_idx: usize = Categorical::uniform(5).unwrap().sample(&mut rng);
53    /// ```
54    fn sample(&self, rng: &mut dyn RngCore) -> T;
55
56    /// Compute the log-probability density (continuous) or mass (discrete) of a value, `x`, from the distribution.
57    ///
58    /// Example:
59    /// ```rust
60    /// # use fugue::*;
61    ///
62    /// // Continuous distribution (probability density)
63    /// let normal = Normal::new(0.0, 1.0).unwrap();
64    /// let density = normal.log_prob(&0.0);  // Peak of standard normal
65    ///
66    /// // Discrete distribution (probability mass)
67    /// let coin = Bernoulli::new(0.7).unwrap();
68    /// let prob_true = coin.log_prob(&true);   // ln(0.7)
69    /// let prob_false = coin.log_prob(&false); // ln(0.3)
70    ///
71    /// // Outside support returns -∞
72    /// let poisson = Poisson::new(3.0).unwrap();
73    /// let invalid = poisson.log_prob(&u64::MAX); // Very unlikely, returns -∞
74    /// ```
75    fn log_prob(&self, x: &T) -> LogF64;
76
77    /// Clone the distribution into a boxed trait object, `Box<dyn Distribution<T>>`.
78    ///
79    /// Example:
80    /// ```rust
81    /// # use fugue::*;
82    ///
83    /// // Clone a distribution into a box
84    /// let original = Normal::new(0.0, 1.0).unwrap();
85    /// let boxed: Box<dyn Distribution<f64>> = original.clone_box();
86    ///
87    /// // Useful for storing different distribution types
88    /// let mut distributions: Vec<Box<dyn Distribution<f64>>> = vec![];
89    /// distributions.push(Normal::new(0.0, 1.0).unwrap().clone_box());
90    /// distributions.push(Uniform::new(-1.0, 1.0).unwrap().clone_box());
91    /// ```
92    fn clone_box(&self) -> Box<dyn Distribution<T>>;
93}
94
95/// A continuous distribution characterized by its mean, `mu`, and standard deviation, `sigma`.
96///
97/// Mathematical Properties:
98/// - **Support**: (-∞, +∞)
99/// - **PDF**: f(x) = (1/(σ√(2π))) × exp(-0.5 × ((x-μ)/σ)²)
100/// - **Mean**: μ
101/// - **Variance**: σ²
102/// - **68-95-99.7 rule**: ~68% within 1σ, ~95% within 2σ, ~99.7% within 3σ
103///
104/// Example:
105/// ```rust
106/// # use fugue::*;
107///
108/// // Standard normal (mean=0, std=1)
109/// let standard = sample(addr!("z"), Normal::new(0.0, 1.0).unwrap());
110///
111/// // Parameter with prior
112/// let theta = sample(addr!("theta"), Normal::new(0.0, 2.0).unwrap());
113///
114/// // Likelihood with observation
115/// let likelihood = observe(addr!("y"), Normal::new(1.5, 0.5).unwrap(), 2.0);
116///
117/// // Measurement error model
118/// let true_value = sample(addr!("true_val"), Normal::new(100.0, 10.0).unwrap());
119/// let measurement = true_value.bind(|val| {
120///     observe(addr!("measured"), Normal::new(val, 2.0).unwrap(), 98.5)
121/// });
122/// ```
123#[derive(Clone, Copy, Debug)]
124pub struct Normal {
125    /// Mean of the normal distribution.
126    mu: f64,
127    /// Standard deviation of the normal distribution (must be positive).
128    sigma: f64,
129}
130impl Normal {
131    /// Create a new Normal distribution with validated parameters.
132    pub fn new(mu: f64, sigma: f64) -> crate::error::FugueResult<Self> {
133        if !mu.is_finite() {
134            return Err(crate::error::FugueError::invalid_parameters(
135                "Normal",
136                "Mean (mu) must be finite",
137                crate::error::ErrorCode::InvalidMean,
138            )
139            .with_context("mu", format!("{}", mu)));
140        }
141        if sigma <= 0.0 || !sigma.is_finite() {
142            return Err(crate::error::FugueError::invalid_parameters(
143                "Normal",
144                "Standard deviation (sigma) must be positive and finite",
145                crate::error::ErrorCode::InvalidVariance,
146            )
147            .with_context("sigma", format!("{}", sigma))
148            .with_context("expected", "> 0.0 and finite"));
149        }
150        Ok(Normal { mu, sigma })
151    }
152
153    /// Get the mean of the distribution.
154    pub fn mu(&self) -> f64 {
155        self.mu
156    }
157
158    /// Get the standard deviation of the distribution.
159    pub fn sigma(&self) -> f64 {
160        self.sigma
161    }
162}
163impl Distribution<f64> for Normal {
164    fn sample(&self, rng: &mut dyn RngCore) -> f64 {
165        if self.sigma <= 0.0 {
166            return f64::NAN;
167        }
168        RDNormal::new(self.mu, self.sigma).unwrap().sample(rng)
169    }
170    fn log_prob(&self, x: &f64) -> LogF64 {
171        // Parameter validation
172        if self.sigma <= 0.0 || !self.sigma.is_finite() || !self.mu.is_finite() || !x.is_finite() {
173            return f64::NEG_INFINITY;
174        }
175
176        // Numerically stable computation
177        let z = (x - self.mu) / self.sigma;
178
179        // Prevent overflow for extreme values (|z| > 37 gives exp(-z²/2) < machine epsilon)
180        if z.abs() > 37.0 {
181            return f64::NEG_INFINITY;
182        }
183
184        // Use precomputed constant for better precision
185        const LN_2PI: f64 = 1.837_877_066_409_345_6; // ln(2π)
186        -0.5 * z * z - self.sigma.ln() - 0.5 * LN_2PI
187    }
188    fn clone_box(&self) -> Box<dyn Distribution<f64>> {
189        Box::new(*self)
190    }
191}
192
193/// A continuous distribution that assigns equal probability density to all values within a specified interval, from `low` to `high`.
194///
195/// Commonly used as an uninformative prior when you want to express complete uncertainty over a bounded range.
196///
197/// Mathematical Properties:
198/// - **Support**: [low, high)
199/// - **PDF**: f(x) = 1/(high-low) for low ≤ x < high, 0 otherwise
200/// - **Mean**: (low + high) / 2
201/// - **Variance**: (high - low)² / 12
202///
203/// Example:
204///
205/// ```rust
206/// # use fugue::*;
207///
208/// // Unit interval [0, 1)
209/// let unit = sample(addr!("p"), Uniform::new(0.0, 1.0).unwrap());
210///
211/// // Symmetric around zero
212/// let symmetric = sample(addr!("x"), Uniform::new(-5.0, 5.0).unwrap());
213///
214/// // Uninformative prior for weight
215/// let weight = sample(addr!("weight"), Uniform::new(0.0, 100.0).unwrap());
216///
217/// // Random angle in radians
218/// let angle = sample(addr!("angle"), Uniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap());
219/// ```
220#[derive(Clone, Copy, Debug)]
221pub struct Uniform {
222    /// Lower bound of the uniform distribution (inclusive).
223    low: f64,
224    /// Upper bound of the uniform distribution (exclusive).
225    high: f64,
226}
227impl Uniform {
228    /// Create a new Uniform distribution with validated parameters.
229    pub fn new(low: f64, high: f64) -> crate::error::FugueResult<Self> {
230        if !low.is_finite() || !high.is_finite() {
231            return Err(crate::error::FugueError::invalid_parameters(
232                "Uniform",
233                "Bounds must be finite",
234                crate::error::ErrorCode::InvalidRange,
235            )
236            .with_context("low", format!("{}", low))
237            .with_context("high", format!("{}", high)));
238        }
239        if low >= high {
240            return Err(crate::error::FugueError::invalid_parameters(
241                "Uniform",
242                "Lower bound must be less than upper bound",
243                crate::error::ErrorCode::InvalidRange,
244            )
245            .with_context("low", format!("{}", low))
246            .with_context("high", format!("{}", high)));
247        }
248        Ok(Uniform { low, high })
249    }
250
251    /// Get the lower bound.
252    pub fn low(&self) -> f64 {
253        self.low
254    }
255
256    /// Get the upper bound.
257    pub fn high(&self) -> f64 {
258        self.high
259    }
260}
261impl Distribution<f64> for Uniform {
262    fn sample(&self, rng: &mut dyn RngCore) -> f64 {
263        // Parameter validation
264        if self.low >= self.high || !self.low.is_finite() || !self.high.is_finite() {
265            return f64::NAN;
266        }
267        Rng::gen_range(rng, self.low..self.high)
268    }
269    fn log_prob(&self, x: &f64) -> LogF64 {
270        // Parameter validation
271        if self.low >= self.high
272            || !self.low.is_finite()
273            || !self.high.is_finite()
274            || !x.is_finite()
275        {
276            return f64::NEG_INFINITY;
277        }
278
279        // Check support with proper boundary handling
280        if *x < self.low || *x >= self.high {
281            f64::NEG_INFINITY
282        } else {
283            let width = self.high - self.low;
284            if width <= 0.0 {
285                f64::NEG_INFINITY
286            } else {
287                -width.ln()
288            }
289        }
290    }
291    fn clone_box(&self) -> Box<dyn Distribution<f64>> {
292        Box::new(*self)
293    }
294}
295
296/// A continuous distribution where the logarithm follows a normal distribution.
297///
298/// Useful for modeling positive-valued quantities that are naturally multiplicative or skewed.
299///
300/// Mathematical Properties:
301/// - **Support**: (0, +∞)
302/// - **PDF**: f(x) = (1/(xσ√(2π))) × exp(-0.5 × ((ln(x)-μ)/σ)²)
303/// - **Mean**: exp(μ + σ²/2)
304/// - **Variance**: (exp(σ²) - 1) × exp(2μ + σ²)
305/// - **Relationship**: If X ~ LogNormal(μ, σ), then ln(X) ~ Normal(μ, σ)
306///
307/// Example:
308/// ```rust
309/// # use fugue::*;
310///
311/// // Standard log-normal (median = 1)
312/// let standard = sample(addr!("x"), LogNormal::new(0.0, 1.0).unwrap());
313///
314/// // Positive scale parameter
315/// let scale = sample(addr!("scale"), LogNormal::new(0.0, 0.5).unwrap());
316///
317/// // Income distribution
318/// let income = sample(addr!("income"), LogNormal::new(10.0, 0.8).unwrap())
319///     .map(|x| x.round() as u64); // Convert to dollars
320///
321/// // Multiplicative error model
322/// let true_value = 100.0;
323/// let measured = sample(addr!("error"), LogNormal::new(0.0, 0.1).unwrap())
324///     .map(move |error| true_value * error);
325/// ```
326#[derive(Clone, Copy, Debug)]
327pub struct LogNormal {
328    /// Mean of the underlying normal distribution.
329    mu: f64,
330    /// Standard deviation of the underlying normal distribution (must be positive).
331    sigma: f64,
332}
333impl LogNormal {
334    /// Create a new LogNormal distribution with validated parameters.
335    pub fn new(mu: f64, sigma: f64) -> crate::error::FugueResult<Self> {
336        if !mu.is_finite() {
337            return Err(crate::error::FugueError::invalid_parameters(
338                "LogNormal",
339                "Mean (mu) must be finite",
340                crate::error::ErrorCode::InvalidMean,
341            )
342            .with_context("mu", format!("{}", mu)));
343        }
344        if sigma <= 0.0 || !sigma.is_finite() {
345            return Err(crate::error::FugueError::invalid_parameters(
346                "LogNormal",
347                "Standard deviation (sigma) must be positive and finite",
348                crate::error::ErrorCode::InvalidVariance,
349            )
350            .with_context("sigma", format!("{}", sigma))
351            .with_context("expected", "> 0.0 and finite"));
352        }
353        Ok(LogNormal { mu, sigma })
354    }
355
356    /// Get the mean of the underlying normal distribution.
357    pub fn mu(&self) -> f64 {
358        self.mu
359    }
360
361    /// Get the standard deviation of the underlying normal distribution.
362    pub fn sigma(&self) -> f64 {
363        self.sigma
364    }
365}
366impl Distribution<f64> for LogNormal {
367    fn sample(&self, rng: &mut dyn RngCore) -> f64 {
368        if self.sigma <= 0.0 {
369            return f64::NAN;
370        }
371        RDLogNormal::new(self.mu, self.sigma).unwrap().sample(rng)
372    }
373    fn log_prob(&self, x: &f64) -> LogF64 {
374        // Parameter and input validation
375        if self.sigma <= 0.0 || !self.sigma.is_finite() || !self.mu.is_finite() {
376            return f64::NEG_INFINITY;
377        }
378        if *x <= 0.0 || !x.is_finite() {
379            return f64::NEG_INFINITY;
380        }
381
382        // Numerically stable computation
383        let lx = x.ln();
384        let z = (lx - self.mu) / self.sigma;
385
386        // Prevent overflow
387        if z.abs() > 37.0 {
388            return f64::NEG_INFINITY;
389        }
390
391        // Stable computation: log_prob = -0.5*z² - ln(x) - ln(σ) - 0.5*ln(2π)
392        const LN_2PI: f64 = 1.837_877_066_409_345_6; // ln(2π)
393        -0.5 * z * z - lx - self.sigma.ln() - 0.5 * LN_2PI
394    }
395    fn clone_box(&self) -> Box<dyn Distribution<f64>> {
396        Box::new(*self)
397    }
398}
399
400/// A continuous distribution often used to model waiting times between events.
401///
402/// Characterized by the memoryless property.
403///
404/// Mathematical Properties:
405/// - **Support**: [0, +∞)
406/// - **PDF**: f(x) = λ × exp(-λx) for x ≥ 0
407/// - **Mean**: 1 / λ
408/// - **Variance**: 1 / λ²
409/// - **Memoryless**: P(X > s + t | X > s) = P(X > t)
410///
411/// Example:
412/// ```rust
413/// # use fugue::*;
414///
415/// // Average wait time of 2 minutes (rate = 0.5 per minute)
416/// let wait_time = sample(addr!("wait"), Exponential::new(0.5).unwrap());
417///
418/// // Service time model
419/// let service = sample(addr!("service_time"), Exponential::new(1.5).unwrap())
420///     .bind(|time| {
421///         if time > 5.0 {
422///             pure("slow")
423///         } else {
424///             pure("fast")
425///         }
426///     });
427///
428/// // Observe actual waiting time
429/// let observed = observe(addr!("actual_wait"), Exponential::new(0.3).unwrap(), 4.2);
430/// ```
431#[derive(Clone, Copy, Debug)]
432pub struct Exponential {
433    /// Rate parameter λ of the exponential distribution (must be positive).
434    rate: f64,
435}
436impl Exponential {
437    /// Create a new Exponential distribution with validated parameters.
438    pub fn new(rate: f64) -> crate::error::FugueResult<Self> {
439        if rate <= 0.0 || !rate.is_finite() {
440            return Err(crate::error::FugueError::invalid_parameters(
441                "Exponential",
442                "Rate parameter must be positive and finite",
443                crate::error::ErrorCode::InvalidRate,
444            )
445            .with_context("rate", format!("{}", rate))
446            .with_context("expected", "> 0.0 and finite"));
447        }
448        Ok(Exponential { rate })
449    }
450
451    /// Get the rate parameter.
452    pub fn rate(&self) -> f64 {
453        self.rate
454    }
455}
456impl Distribution<f64> for Exponential {
457    fn sample(&self, rng: &mut dyn RngCore) -> f64 {
458        if self.rate <= 0.0 {
459            return f64::NAN;
460        }
461        RDExp::new(self.rate).unwrap().sample(rng)
462    }
463    fn log_prob(&self, x: &f64) -> LogF64 {
464        // Parameter validation
465        if self.rate <= 0.0 || !self.rate.is_finite() || !x.is_finite() {
466            return f64::NEG_INFINITY;
467        }
468
469        if *x < 0.0 {
470            f64::NEG_INFINITY
471        } else {
472            // Check for overflow: if rate * x > 700, exp(-rate*x) underflows
473            if self.rate * x > 700.0 {
474                return f64::NEG_INFINITY;
475            }
476            self.rate.ln() - self.rate * x
477        }
478    }
479    fn clone_box(&self) -> Box<dyn Distribution<f64>> {
480        Box::new(*self)
481    }
482}
483
484/// A discrete distribution for binary outcomes (true/false, success/failure).
485///
486/// Returns `bool` directly for type-safe boolean logic.
487///
488/// Mathematical Properties:
489/// - **Support**: {false, true}
490/// - **PMF**: P(X = true) = p, P(X = false) = 1 - p
491/// - **Mean**: p
492/// - **Variance**: p(1 - p)
493///
494/// Example:
495/// ```rust
496/// # use fugue::*;
497///
498/// // Fair coin flip
499/// let coin = sample(addr!("coin"), Bernoulli::new(0.5).unwrap());
500/// let result = coin.bind(|heads| {
501///     if heads {
502///         pure("Heads!")
503///     } else {
504///         pure("Tails!")
505///     }
506/// });
507///
508/// // Biased coin with observation
509/// let biased = observe(addr!("biased_coin"), Bernoulli::new(0.7).unwrap(), true);
510/// ```
511#[derive(Clone, Copy, Debug)]
512pub struct Bernoulli {
513    /// Probability of success (must be in [0, 1]).
514    p: f64,
515}
516impl Bernoulli {
517    /// Create a new Bernoulli distribution with validated parameters.
518    pub fn new(p: f64) -> crate::error::FugueResult<Self> {
519        if !p.is_finite() || !(0.0..=1.0).contains(&p) {
520            return Err(crate::error::FugueError::invalid_parameters(
521                "Bernoulli",
522                "Probability must be in [0, 1]",
523                crate::error::ErrorCode::InvalidProbability,
524            )
525            .with_context("p", format!("{}", p))
526            .with_context("expected", "[0.0, 1.0]"));
527        }
528        Ok(Bernoulli { p })
529    }
530
531    /// Get the success probability.
532    pub fn p(&self) -> f64 {
533        self.p
534    }
535}
536impl Distribution<bool> for Bernoulli {
537    fn sample(&self, rng: &mut dyn RngCore) -> bool {
538        if self.p < 0.0 || self.p > 1.0 || !self.p.is_finite() {
539            return false; // Default to false for invalid parameters
540        }
541        use rand::Rng;
542        rng.gen::<f64>() < self.p
543    }
544    fn log_prob(&self, x: &bool) -> LogF64 {
545        // Parameter validation
546        if self.p < 0.0 || self.p > 1.0 || !self.p.is_finite() {
547            return f64::NEG_INFINITY;
548        }
549
550        if *x {
551            // P(X = true) = p
552            if self.p <= 0.0 {
553                f64::NEG_INFINITY
554            } else {
555                self.p.ln()
556            }
557        } else {
558            // P(X = false) = 1 - p
559            if self.p >= 1.0 {
560                f64::NEG_INFINITY
561            } else {
562                (1.0 - self.p).ln()
563            }
564        }
565    }
566    fn clone_box(&self) -> Box<dyn Distribution<bool>> {
567        Box::new(*self)
568    }
569}
570
571/// A discrete distribution for choosing among multiple categories with specified probabilities.
572///
573/// Returns `usize` for safe array indexing.
574///
575/// Mathematical Properties:
576/// - **Support**: {0, 1, ..., k-1} where k = number of categories
577/// - **PMF**: P(X = i) = probs[i]
578/// - **Mean**: Σ(i × probs[i])
579/// - **Variance**: Σ(i² × probs[i]) - mean²
580///
581/// Example:
582/// ```rust
583/// # use fugue::*;
584///
585/// // Custom probabilities
586/// let weighted = Categorical::new(vec![0.1, 0.2, 0.3, 0.4]).unwrap();
587///
588/// // Uniform distribution over k categories
589/// let uniform = Categorical::uniform(4).unwrap();
590///
591/// // Choose from three options
592/// let options = vec!["red", "green", "blue"];
593/// let choice = sample(addr!("color"), Categorical::new(vec![0.5, 0.3, 0.2]).unwrap())
594///     .map(move |idx| options[idx].to_string());
595///
596/// // Observe a specific choice
597/// let observed = observe(addr!("user_choice"),
598///     Categorical::uniform(3).unwrap(), 1usize);
599/// ```
600#[derive(Clone, Debug)]
601pub struct Categorical {
602    /// Probabilities for each category (should sum to 1.0).
603    probs: Vec<f64>,
604}
605impl Categorical {
606    /// Create a new Categorical distribution with validated parameters.
607    pub fn new(probs: Vec<f64>) -> crate::error::FugueResult<Self> {
608        if probs.is_empty() {
609            return Err(crate::error::FugueError::invalid_parameters(
610                "Categorical",
611                "Probability vector cannot be empty",
612                crate::error::ErrorCode::InvalidProbability,
613            )
614            .with_context("length", "0"));
615        }
616
617        let sum: f64 = probs.iter().sum();
618        if (sum - 1.0).abs() > 1e-6 {
619            return Err(crate::error::FugueError::invalid_parameters(
620                "Categorical",
621                "Probabilities must sum to 1.0",
622                crate::error::ErrorCode::InvalidProbability,
623            )
624            .with_context("sum", format!("{:.6}", sum))
625            .with_context("expected", "1.0")
626            .with_context("tolerance", "1e-6"));
627        }
628
629        for (i, &p) in probs.iter().enumerate() {
630            if !p.is_finite() || p < 0.0 {
631                return Err(crate::error::FugueError::invalid_parameters(
632                    "Categorical",
633                    "All probabilities must be non-negative and finite",
634                    crate::error::ErrorCode::InvalidProbability,
635                )
636                .with_context("index", format!("{}", i))
637                .with_context("value", format!("{}", p))
638                .with_context("expected", ">= 0.0 and finite"));
639            }
640        }
641
642        Ok(Categorical { probs })
643    }
644
645    /// Create a uniform categorical distribution over k categories.
646    pub fn uniform(k: usize) -> crate::error::FugueResult<Self> {
647        if k == 0 {
648            return Err(crate::error::FugueError::invalid_parameters(
649                "Categorical",
650                "Number of categories must be positive",
651                crate::error::ErrorCode::InvalidCount,
652            )
653            .with_context("k", "0"));
654        }
655
656        let prob = 1.0 / k as f64;
657        let probs = vec![prob; k];
658        Ok(Categorical { probs })
659    }
660
661    /// Get the probability vector.
662    pub fn probs(&self) -> &[f64] {
663        &self.probs
664    }
665
666    /// Get the number of categories.
667    pub fn len(&self) -> usize {
668        self.probs.len()
669    }
670
671    /// Check if the distribution has no categories.
672    pub fn is_empty(&self) -> bool {
673        self.probs.is_empty()
674    }
675}
676impl Distribution<usize> for Categorical {
677    fn sample(&self, rng: &mut dyn RngCore) -> usize {
678        // Parameter validation
679        if self.probs.is_empty() {
680            return 0;
681        }
682
683        let prob_sum: f64 = self.probs.iter().sum();
684        if (prob_sum - 1.0).abs() > 1e-6 || self.probs.iter().any(|&p| p < 0.0 || !p.is_finite()) {
685            return 0;
686        }
687
688        use rand::Rng;
689        let u: f64 = rng.gen();
690        let mut cum = 0.0;
691        for (i, &p) in self.probs.iter().enumerate() {
692            cum += p;
693            if u <= cum {
694                return i;
695            }
696        }
697        self.probs.len() - 1
698    }
699    fn log_prob(&self, x: &usize) -> LogF64 {
700        // Parameter validation
701        if self.probs.is_empty() || *x >= self.probs.len() {
702            return f64::NEG_INFINITY;
703        }
704
705        let prob_sum: f64 = self.probs.iter().sum();
706        if (prob_sum - 1.0).abs() > 1e-6 || self.probs.iter().any(|&p| p < 0.0 || !p.is_finite()) {
707            return f64::NEG_INFINITY;
708        }
709
710        if self.probs[*x] <= 0.0 {
711            f64::NEG_INFINITY
712        } else {
713            self.probs[*x].ln()
714        }
715    }
716    fn clone_box(&self) -> Box<dyn Distribution<usize>> {
717        Box::new(self.clone())
718    }
719}
720
721/// A continuous distribution on the interval (0, 1), commonly used for modeling probabilities and proportions.
722///
723/// Conjugate prior for Bernoulli/Binomial distributions.
724///
725/// Mathematical Properties:
726/// - **Support**: (0, 1)
727/// - **PDF**: f(x) = (x^(α-1) × (1-x)^(β-1)) / B(α,β)
728/// - **Mean**: α / (α + β)
729/// - **Variance**: (αβ) / ((α+β)²(α+β+1))
730///
731/// Example:
732/// ```rust
733/// # use fugue::*;
734///
735/// // Uniform on [0,1]
736/// let uniform = sample(addr!("p"), Beta::new(1.0, 1.0).unwrap());
737///
738/// // Prior for success probability
739/// let prob_prior = sample(addr!("success_rate"), Beta::new(2.0, 5.0).unwrap());
740///
741/// // Conjugate prior-likelihood pair
742/// let model = sample(addr!("p"), Beta::new(3.0, 7.0).unwrap())
743///     .bind(|p| observe(addr!("trial"), Bernoulli::new(p).unwrap(), true));
744///
745/// // Skewed towards 0 (beta > alpha)
746/// let skewed = sample(addr!("proportion"), Beta::new(2.0, 8.0).unwrap());
747/// ```
748#[derive(Clone, Copy, Debug)]
749pub struct Beta {
750    /// First shape parameter α (must be positive).
751    alpha: f64,
752    /// Second shape parameter β (must be positive).
753    beta: f64,
754}
755impl Beta {
756    /// Create a new Beta distribution with validated parameters.
757    pub fn new(alpha: f64, beta: f64) -> crate::error::FugueResult<Self> {
758        if alpha <= 0.0 || !alpha.is_finite() {
759            return Err(crate::error::FugueError::invalid_parameters(
760                "Beta",
761                "Alpha parameter must be positive and finite",
762                crate::error::ErrorCode::InvalidShape,
763            )
764            .with_context("alpha", format!("{}", alpha))
765            .with_context("expected", "> 0.0 and finite"));
766        }
767        if beta <= 0.0 || !beta.is_finite() {
768            return Err(crate::error::FugueError::invalid_parameters(
769                "Beta",
770                "Beta parameter must be positive and finite",
771                crate::error::ErrorCode::InvalidShape,
772            )
773            .with_context("beta", format!("{}", beta))
774            .with_context("expected", "> 0.0 and finite"));
775        }
776        Ok(Beta { alpha, beta })
777    }
778
779    /// Get the alpha parameter.
780    pub fn alpha(&self) -> f64 {
781        self.alpha
782    }
783
784    /// Get the beta parameter.
785    pub fn beta(&self) -> f64 {
786        self.beta
787    }
788}
789impl Distribution<f64> for Beta {
790    fn sample(&self, rng: &mut dyn RngCore) -> f64 {
791        if self.alpha <= 0.0 || self.beta <= 0.0 {
792            return f64::NAN;
793        }
794        RDBeta::new(self.alpha, self.beta).unwrap().sample(rng)
795    }
796    fn log_prob(&self, x: &f64) -> LogF64 {
797        // Parameter validation
798        if self.alpha <= 0.0
799            || self.beta <= 0.0
800            || !self.alpha.is_finite()
801            || !self.beta.is_finite()
802            || !x.is_finite()
803        {
804            return f64::NEG_INFINITY;
805        }
806
807        // Support validation
808        if *x <= 0.0 || *x >= 1.0 {
809            return f64::NEG_INFINITY;
810        }
811
812        // Handle edge cases near boundaries
813        if *x < 1e-100 || *x > 1.0 - 1e-100 {
814            return f64::NEG_INFINITY;
815        }
816
817        // Numerically stable computation using log-gamma
818        // log Beta(x; α, β) = (α-1)ln(x) + (β-1)ln(1-x) - log B(α,β)
819        let log_beta_fn = libm::lgamma(self.alpha) + libm::lgamma(self.beta)
820            - libm::lgamma(self.alpha + self.beta);
821
822        let ln_x = x.ln();
823        let ln_1_minus_x = (1.0 - x).ln();
824
825        // Check for extreme log values
826        if ln_x < -700.0 || ln_1_minus_x < -700.0 {
827            return f64::NEG_INFINITY;
828        }
829
830        (self.alpha - 1.0) * ln_x + (self.beta - 1.0) * ln_1_minus_x - log_beta_fn
831    }
832    fn clone_box(&self) -> Box<dyn Distribution<f64>> {
833        Box::new(*self)
834    }
835}
836
837/// A continuous distribution over positive real numbers, parameterized by shape and rate.
838///
839/// Commonly used for modeling waiting times and as a conjugate prior for Poisson distributions.
840///
841/// Mathematical Properties:
842/// - **Support**: (0, +∞)
843/// - **PDF**: f(x) = (λ^k / Γ(k)) × x^(k-1) × exp(-λx)
844/// - **Mean**: k / λ
845/// - **Variance**: k / λ²
846///
847/// Example:
848/// ```rust
849/// # use fugue::*;
850///
851/// // Shape=1 gives Exponential distribution
852/// let exponential_like = sample(addr!("wait_time"), Gamma::new(1.0, 2.0).unwrap());
853///
854/// // Prior for precision parameter
855/// let precision = sample(addr!("precision"), Gamma::new(2.0, 1.0).unwrap());
856///
857/// // Conjugate prior for Poisson rate
858/// let model = sample(addr!("rate"), Gamma::new(3.0, 2.0).unwrap())
859///     .bind(|lambda| observe(addr!("count"), Poisson::new(lambda).unwrap(), 5u64));
860///
861/// // Scale parameter (rate = 1/scale)
862/// let scale_param = sample(addr!("scale"), Gamma::new(2.0, 0.5).unwrap()); // mean = 4
863/// ```
864#[derive(Clone, Copy, Debug)]
865pub struct Gamma {
866    /// Shape parameter k (must be positive).
867    shape: f64,
868    /// Rate parameter λ (must be positive).
869    rate: f64,
870}
871impl Gamma {
872    /// Create a new Gamma distribution with validated parameters.
873    pub fn new(shape: f64, rate: f64) -> crate::error::FugueResult<Self> {
874        if shape <= 0.0 || !shape.is_finite() {
875            return Err(crate::error::FugueError::invalid_parameters(
876                "Gamma",
877                "Shape parameter must be positive and finite",
878                crate::error::ErrorCode::InvalidShape,
879            )
880            .with_context("shape", format!("{}", shape))
881            .with_context("expected", "> 0.0 and finite"));
882        }
883        if rate <= 0.0 || !rate.is_finite() {
884            return Err(crate::error::FugueError::invalid_parameters(
885                "Gamma",
886                "Rate parameter must be positive and finite",
887                crate::error::ErrorCode::InvalidRate,
888            )
889            .with_context("rate", format!("{}", rate))
890            .with_context("expected", "> 0.0 and finite"));
891        }
892        Ok(Gamma { shape, rate })
893    }
894
895    /// Get the shape parameter.
896    pub fn shape(&self) -> f64 {
897        self.shape
898    }
899
900    /// Get the rate parameter.
901    pub fn rate(&self) -> f64 {
902        self.rate
903    }
904}
905impl Distribution<f64> for Gamma {
906    fn sample(&self, rng: &mut dyn RngCore) -> f64 {
907        if self.shape <= 0.0 || self.rate <= 0.0 {
908            return f64::NAN;
909        }
910        RDGamma::new(self.shape, 1.0 / self.rate)
911            .unwrap()
912            .sample(rng)
913    }
914    fn log_prob(&self, x: &f64) -> LogF64 {
915        // Parameter validation
916        if self.shape <= 0.0
917            || self.rate <= 0.0
918            || !self.shape.is_finite()
919            || !self.rate.is_finite()
920            || !x.is_finite()
921        {
922            return f64::NEG_INFINITY;
923        }
924
925        if *x <= 0.0 {
926            return f64::NEG_INFINITY;
927        }
928
929        // Check for overflow conditions
930        if self.rate * x > 700.0 || x.ln() * (self.shape - 1.0) < -700.0 {
931            return f64::NEG_INFINITY;
932        }
933
934        // Numerically stable computation
935        // log Gamma(x; k, λ) = k*ln(λ) + (k-1)*ln(x) - λ*x - ln Γ(k)
936        let log_rate = self.rate.ln();
937        let log_x = x.ln();
938        let log_gamma_shape = libm::lgamma(self.shape);
939
940        self.shape * log_rate + (self.shape - 1.0) * log_x - self.rate * x - log_gamma_shape
941    }
942    fn clone_box(&self) -> Box<dyn Distribution<f64>> {
943        Box::new(*self)
944    }
945}
946
947/// A discrete distribution representing the number of successes in n independent trials, with probability of success p.
948///
949/// Returns `u64` for natural success counting.
950///
951/// Mathematical Properties:
952/// - **Support**: {0, 1, ..., n}
953/// - **PMF**: P(X = k) = C(n,k) × p^k × (1-p)^(n-k)
954/// - **Mean**: n × p
955/// - **Variance**: n × p × (1-p)
956///
957/// Example:
958/// ```rust
959/// # use fugue::*;
960///
961/// // 10 coin flips
962/// let successes = sample(addr!("heads"), Binomial::new(10, 0.5).unwrap())
963///     .bind(|count| {
964///         let rate = count as f64 / 10.0;
965///         pure(format!("Success rate: {:.1}%", rate * 100.0))
966///     });
967///
968/// // Clinical trial
969/// let trial = sample(addr!("success_rate"), Beta::new(1.0, 1.0).unwrap())
970///     .bind(|p| sample(addr!("successes"), Binomial::new(100, p).unwrap()));
971///
972/// // Observe trial results
973/// let observed = observe(addr!("trial_successes"), Binomial::new(20, 0.3).unwrap(), 7u64);
974/// ```
975#[derive(Clone, Copy, Debug)]
976pub struct Binomial {
977    /// Number of trials.
978    n: u64,
979    /// Probability of success on each trial (must be in [0, 1]).
980    p: f64,
981}
982impl Binomial {
983    /// Create a new Binomial distribution with validated parameters.
984    pub fn new(n: u64, p: f64) -> crate::error::FugueResult<Self> {
985        if !p.is_finite() || !(0.0..=1.0).contains(&p) {
986            return Err(crate::error::FugueError::invalid_parameters(
987                "Binomial",
988                "Probability must be in [0, 1]",
989                crate::error::ErrorCode::InvalidProbability,
990            )
991            .with_context("p", format!("{}", p))
992            .with_context("expected", "[0.0, 1.0]"));
993        }
994        Ok(Binomial { n, p })
995    }
996
997    /// Get the number of trials.
998    pub fn n(&self) -> u64 {
999        self.n
1000    }
1001
1002    /// Get the success probability.
1003    pub fn p(&self) -> f64 {
1004        self.p
1005    }
1006}
1007impl Distribution<u64> for Binomial {
1008    fn sample(&self, rng: &mut dyn RngCore) -> u64 {
1009        RDBinomial::new(self.n, self.p).unwrap().sample(rng)
1010    }
1011    fn log_prob(&self, x: &u64) -> LogF64 {
1012        let k = *x;
1013        if k > self.n {
1014            return f64::NEG_INFINITY;
1015        }
1016        // log Binomial(k; n, p) = log C(n,k) + k*ln(p) + (n-k)*ln(1-p)
1017        let log_binom_coeff = libm::lgamma(self.n as f64 + 1.0)
1018            - libm::lgamma(k as f64 + 1.0)
1019            - libm::lgamma((self.n - k) as f64 + 1.0);
1020        log_binom_coeff + (k as f64) * self.p.ln() + ((self.n - k) as f64) * (1.0 - self.p).ln()
1021    }
1022    fn clone_box(&self) -> Box<dyn Distribution<u64>> {
1023        Box::new(*self)
1024    }
1025}
1026
1027/// A discrete distribution for modeling the number of events occurring in a fixed interval.
1028///
1029/// Returns `u64` for natural counting arithmetic.
1030///
1031/// Mathematical Properties:
1032/// - **Support**: {0, 1, 2, 3, ...}
1033/// - **PMF**: P(X = k) = (λ^k × e^(-λ)) / k!
1034/// - **Mean**: λ
1035/// - **Variance**: λ
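/// - **Waiting times**: gaps between events of a rate-λ Poisson process are Exponential(λ);
///   memorylessness is a property of those gaps, not of the count itself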
1037///
1038/// Example:
1039/// ```rust
1040/// # use fugue::*;
1041///
1042/// // Model event counts
1043/// let events = sample(addr!("events"), Poisson::new(3.0).unwrap())
1044///     .bind(|count| {
1045///         let status = match count {
1046///             0 => "No events",
1047///             1 => "Single event",
1048///             n if n > 10 => "High activity",
1049///             _ => "Normal activity"
1050///         };
1051///         pure(status.to_string())
1052///     });
1053///
1054/// // Hierarchical model with Gamma prior
1055/// let hierarchical = sample(addr!("rate"), Gamma::new(2.0, 1.0).unwrap())
1056///     .bind(|lambda| sample(addr!("count"), Poisson::new(lambda).unwrap()));
1057///
1058/// // Observe count data
1059/// let observed = observe(addr!("observed_count"), Poisson::new(4.0).unwrap(), 7u64);
1060/// ```
1061#[derive(Clone, Copy, Debug)]
1062pub struct Poisson {
1063    /// Rate parameter λ (must be positive). Mean and variance of the distribution.
1064    lambda: f64,
1065}
1066impl Poisson {
1067    /// Create a new Poisson distribution with validated parameters.
1068    pub fn new(lambda: f64) -> crate::error::FugueResult<Self> {
1069        if lambda <= 0.0 || !lambda.is_finite() {
1070            return Err(crate::error::FugueError::invalid_parameters(
1071                "Poisson",
1072                "Rate parameter lambda must be positive and finite",
1073                crate::error::ErrorCode::InvalidRate,
1074            )
1075            .with_context("lambda", format!("{}", lambda))
1076            .with_context("expected", "> 0.0 and finite"));
1077        }
1078        Ok(Poisson { lambda })
1079    }
1080
1081    /// Get the rate parameter.
1082    pub fn lambda(&self) -> f64 {
1083        self.lambda
1084    }
1085}
1086impl Distribution<u64> for Poisson {
1087    fn sample(&self, rng: &mut dyn RngCore) -> u64 {
1088        if self.lambda <= 0.0 || !self.lambda.is_finite() {
1089            return 0;
1090        }
1091        RDPoisson::new(self.lambda).unwrap().sample(rng) as u64
1092    }
1093    fn log_prob(&self, x: &u64) -> LogF64 {
1094        // Parameter validation
1095        if self.lambda <= 0.0 || !self.lambda.is_finite() {
1096            return f64::NEG_INFINITY;
1097        }
1098
1099        let k = *x;
1100
1101        // Handle extreme cases
1102        if self.lambda > 700.0 && k == 0 {
1103            return -self.lambda; // Direct computation to avoid lgamma issues
1104        }
1105
1106        // Numerically stable computation
1107        // log Poisson(k; λ) = k*ln(λ) - λ - ln(k!)
1108        let k_f64 = k as f64;
1109        let log_lambda = self.lambda.ln();
1110        let log_factorial = libm::lgamma(k_f64 + 1.0);
1111
1112        k_f64 * log_lambda - self.lambda - log_factorial
1113    }
1114    fn clone_box(&self) -> Box<dyn Distribution<u64>> {
1115        Box::new(*self)
1116    }
1117}
1118
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Absolute tolerance for comparisons against closed-form log-probabilities.
    const TOL: f64 = 1e-12;

    /// Whether `actual` lies strictly within [`TOL`] of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    #[test]
    fn normal_constructor_and_log_prob() {
        assert!(Normal::new(0.0, 1.0).is_ok());
        for (mu, sigma) in [(f64::NAN, 1.0), (0.0, 0.0)] {
            assert!(Normal::new(mu, sigma).is_err(), "mu={}, sigma={}", mu, sigma);
        }

        let standard = Normal::new(0.0, 1.0).unwrap();
        assert!(standard.log_prob(&0.0).is_finite());
        assert_eq!(standard.log_prob(&f64::INFINITY), f64::NEG_INFINITY);
    }

    #[test]
    fn uniform_support_and_log_prob() {
        assert!(Uniform::new(0.0, 1.0).is_ok());
        assert!(Uniform::new(1.0, 0.0).is_err());

        let u = Uniform::new(-2.0, 2.0).unwrap();
        // An interior point has finite density; the upper bound and points below the
        // lower bound do not.
        assert!(u.log_prob(&0.0).is_finite());
        for outside in [2.0, -2.1] {
            assert_eq!(u.log_prob(&outside), f64::NEG_INFINITY);
        }
    }

    #[test]
    fn lognormal_validation() {
        assert!(LogNormal::new(0.0, 1.0).is_ok());
        assert!(LogNormal::new(0.0, 0.0).is_err());

        let log_normal = LogNormal::new(0.0, 1.0).unwrap();
        assert_eq!(log_normal.log_prob(&0.0), f64::NEG_INFINITY);
        assert!(log_normal.log_prob(&1.0).is_finite());
    }

    #[test]
    fn exponential_validation() {
        assert!(Exponential::new(1.0).is_ok());
        assert!(Exponential::new(0.0).is_err());

        let expo = Exponential::new(2.0).unwrap();
        assert_eq!(expo.log_prob(&-1.0), f64::NEG_INFINITY);
        assert!(close(expo.log_prob(&0.0), 2.0f64.ln()));
    }

    #[test]
    fn bernoulli_validation() {
        assert!(Bernoulli::new(0.5).is_ok());
        assert!(Bernoulli::new(-0.1).is_err());

        let coin = Bernoulli::new(0.25).unwrap();
        assert!(close(coin.log_prob(&true), 0.25f64.ln()));
        assert!(close(coin.log_prob(&false), 0.75f64.ln()));
    }

    #[test]
    fn categorical_validation_and_log_prob() {
        assert!(Categorical::new(vec![0.5, 0.5]).is_ok());
        assert!(Categorical::new(vec![]).is_err());
        assert!(Categorical::new(vec![0.6, 0.5]).is_err());

        let cat = Categorical::new(vec![0.2, 0.8]).unwrap();
        assert!(close(cat.log_prob(&1), 0.8f64.ln()));
        assert_eq!(cat.log_prob(&2), f64::NEG_INFINITY);
    }

    #[test]
    fn beta_validation_and_support() {
        assert!(Beta::new(2.0, 3.0).is_ok());
        assert!(Beta::new(0.0, 1.0).is_err());

        let beta = Beta::new(2.0, 5.0).unwrap();
        for boundary in [0.0, 1.0] {
            assert_eq!(beta.log_prob(&boundary), f64::NEG_INFINITY);
        }
        assert!(beta.log_prob(&0.5).is_finite());
    }

    #[test]
    fn gamma_validation_and_support() {
        assert!(Gamma::new(1.5, 2.0).is_ok());
        for (shape, rate) in [(0.0, 2.0), (1.0, 0.0)] {
            assert!(Gamma::new(shape, rate).is_err(), "shape={}, rate={}", shape, rate);
        }

        let gamma = Gamma::new(2.0, 1.0).unwrap();
        assert_eq!(gamma.log_prob(&-1.0), f64::NEG_INFINITY);
        assert!(gamma.log_prob(&1.0).is_finite());
    }

    #[test]
    fn binomial_validation_and_log_prob() {
        assert!(Binomial::new(10, 0.5).is_ok());
        assert!(Binomial::new(10, 1.5).is_err());

        let binom = Binomial::new(5, 0.3).unwrap();
        // k > n lies outside the support.
        assert_eq!(binom.log_prob(&6), f64::NEG_INFINITY);
        assert!(binom.log_prob(&3).is_finite());
    }

    #[test]
    fn poisson_validation_and_log_prob() {
        assert!(Poisson::new(1.0).is_ok());
        assert!(Poisson::new(0.0).is_err());

        let pois = Poisson::new(3.0).unwrap();
        for k in [0, 5] {
            assert!(pois.log_prob(&k).is_finite());
        }
    }

    #[test]
    fn sampling_basic_sanity() {
        let mut rng = StdRng::seed_from_u64(42);

        let normal_draw = Normal::new(0.0, 1.0).unwrap().sample(&mut rng);
        assert!(normal_draw.is_finite());

        let uniform_draw = Uniform::new(-1.0, 2.0).unwrap().sample(&mut rng);
        assert!((-1.0..2.0).contains(&uniform_draw));

        let _flip = Bernoulli::new(0.7).unwrap().sample(&mut rng);
    }

    #[test]
    fn categorical_uniform_constructor() {
        let cu = Categorical::uniform(4).unwrap();
        assert_eq!(cu.len(), 4);
        assert!(cu.probs().iter().all(|&p| close(p, 0.25)));
    }
}