Skip to main content

cyanea_stats/
bayesian.rs

//! Bayesian conjugate prior distributions.
//!
//! Provides four conjugate prior–likelihood pairs commonly used in
//! bioinformatics:
//!
//! - [`Beta`] — conjugate to the binomial likelihood
//! - [`Gamma`] — conjugate to the Poisson likelihood
//! - [`NormalConjugate`] — normal prior for a normal likelihood with known variance
//! - [`Dirichlet`] — conjugate to the multinomial likelihood
10
11use crate::distribution::{betai, gammainc, ln_gamma, Distribution};
12use cyanea_core::{CyaneaError, Result};
13
14// ── Beta distribution ────────────────────────────────────────────────────
15
16/// Beta distribution, conjugate prior for binomial likelihood.
17///
18/// After observing `s` successes in `n` trials, the posterior is
19/// `Beta(α + s, β + n − s)`.
20#[derive(Debug, Clone, Copy)]
21pub struct Beta {
22    alpha: f64,
23    beta: f64,
24}
25
26impl Beta {
27    /// Create a Beta distribution with shape parameters `alpha` and `beta`.
28    ///
29    /// Both must be positive.
30    pub fn new(alpha: f64, beta: f64) -> Result<Self> {
31        if alpha <= 0.0 || beta <= 0.0 {
32            return Err(CyaneaError::InvalidInput(
33                "Beta: alpha and beta must be positive".into(),
34            ));
35        }
36        Ok(Self { alpha, beta })
37    }
38
39    /// Alpha parameter.
40    pub fn alpha(&self) -> f64 {
41        self.alpha
42    }
43
44    /// Beta parameter.
45    pub fn beta(&self) -> f64 {
46        self.beta
47    }
48
49    /// Compute the posterior after observing binomial data.
50    pub fn update_binomial(&self, successes: u64, trials: u64) -> Self {
51        Self {
52            alpha: self.alpha + successes as f64,
53            beta: self.beta + (trials - successes) as f64,
54        }
55    }
56}
57
58impl Distribution for Beta {
59    fn pdf(&self, x: f64) -> f64 {
60        if x <= 0.0 || x >= 1.0 {
61            return 0.0;
62        }
63        let ln_beta_fn = ln_gamma(self.alpha) + ln_gamma(self.beta)
64            - ln_gamma(self.alpha + self.beta);
65        let ln_pdf = (self.alpha - 1.0) * x.ln()
66            + (self.beta - 1.0) * (1.0 - x).ln()
67            - ln_beta_fn;
68        ln_pdf.exp()
69    }
70
71    fn cdf(&self, x: f64) -> f64 {
72        if x <= 0.0 {
73            return 0.0;
74        }
75        if x >= 1.0 {
76            return 1.0;
77        }
78        betai(self.alpha, self.beta, x).unwrap_or(0.0)
79    }
80
81    fn mean(&self) -> f64 {
82        self.alpha / (self.alpha + self.beta)
83    }
84
85    fn variance(&self) -> f64 {
86        let ab = self.alpha + self.beta;
87        (self.alpha * self.beta) / (ab * ab * (ab + 1.0))
88    }
89}
90
91// ── Gamma distribution ───────────────────────────────────────────────────
92
93/// Gamma distribution (shape/rate parameterization), conjugate prior for
94/// Poisson likelihood.
95///
96/// After observing a count `c` (one observation), the posterior is
97/// `Gamma(shape + c, rate + 1)`.
98#[derive(Debug, Clone, Copy)]
99pub struct Gamma {
100    shape: f64,
101    rate: f64,
102}
103
104impl Gamma {
105    /// Create a Gamma distribution with given `shape` (α) and `rate` (β).
106    ///
107    /// Both must be positive.
108    pub fn new(shape: f64, rate: f64) -> Result<Self> {
109        if shape <= 0.0 || rate <= 0.0 {
110            return Err(CyaneaError::InvalidInput(
111                "Gamma: shape and rate must be positive".into(),
112            ));
113        }
114        Ok(Self { shape, rate })
115    }
116
117    /// Shape parameter (α).
118    pub fn shape(&self) -> f64 {
119        self.shape
120    }
121
122    /// Rate parameter (β).
123    pub fn rate(&self) -> f64 {
124        self.rate
125    }
126
127    /// Posterior after observing a single Poisson count.
128    pub fn update_poisson(&self, count: u64) -> Self {
129        Self {
130            shape: self.shape + count as f64,
131            rate: self.rate + 1.0,
132        }
133    }
134
135    /// Posterior after observing multiple Poisson counts.
136    pub fn update_poisson_batch(&self, counts: &[u64]) -> Self {
137        let total: u64 = counts.iter().sum();
138        Self {
139            shape: self.shape + total as f64,
140            rate: self.rate + counts.len() as f64,
141        }
142    }
143}
144
145impl Distribution for Gamma {
146    fn pdf(&self, x: f64) -> f64 {
147        if x <= 0.0 {
148            return 0.0;
149        }
150        let ln_pdf = self.shape * self.rate.ln() - ln_gamma(self.shape)
151            + (self.shape - 1.0) * x.ln()
152            - self.rate * x;
153        ln_pdf.exp()
154    }
155
156    fn cdf(&self, x: f64) -> f64 {
157        if x <= 0.0 {
158            return 0.0;
159        }
160        // P(a, rate * x) but gammainc takes shape and x
161        // For Gamma(shape, rate): CDF(x) = P(shape, rate * x) = gammainc(shape, rate * x)
162        gammainc(self.shape, self.rate * x).unwrap_or(0.0)
163    }
164
165    fn mean(&self) -> f64 {
166        self.shape / self.rate
167    }
168
169    fn variance(&self) -> f64 {
170        self.shape / (self.rate * self.rate)
171    }
172}
173
174// ── Normal conjugate ─────────────────────────────────────────────────────
175
176/// Normal prior for a normal likelihood with known observation variance.
177///
178/// Uses the precision (inverse variance) formulation for numerically
179/// stable Bayesian updates.
180#[derive(Debug, Clone, Copy)]
181pub struct NormalConjugate {
182    prior_mu: f64,
183    prior_var: f64,
184    obs_var: f64,
185}
186
187impl NormalConjugate {
188    /// Create a normal conjugate prior.
189    ///
190    /// - `prior_mu`: prior mean
191    /// - `prior_var`: prior variance (must be positive)
192    /// - `obs_var`: known observation variance (must be positive)
193    pub fn new(prior_mu: f64, prior_var: f64, obs_var: f64) -> Result<Self> {
194        if prior_var <= 0.0 {
195            return Err(CyaneaError::InvalidInput(
196                "NormalConjugate: prior_var must be positive".into(),
197            ));
198        }
199        if obs_var <= 0.0 {
200            return Err(CyaneaError::InvalidInput(
201                "NormalConjugate: obs_var must be positive".into(),
202            ));
203        }
204        Ok(Self {
205            prior_mu,
206            prior_var,
207            obs_var,
208        })
209    }
210
211    /// Update with a single observation.
212    pub fn update(&self, observation: f64) -> Self {
213        let prior_prec = 1.0 / self.prior_var;
214        let obs_prec = 1.0 / self.obs_var;
215        let post_prec = prior_prec + obs_prec;
216        let post_var = 1.0 / post_prec;
217        let post_mu = (prior_prec * self.prior_mu + obs_prec * observation) / post_prec;
218        Self {
219            prior_mu: post_mu,
220            prior_var: post_var,
221            obs_var: self.obs_var,
222        }
223    }
224
225    /// Update with a batch of observations.
226    pub fn update_batch(&self, observations: &[f64]) -> Self {
227        let n = observations.len() as f64;
228        if n == 0.0 {
229            return *self;
230        }
231        let obs_mean: f64 = observations.iter().sum::<f64>() / n;
232        let prior_prec = 1.0 / self.prior_var;
233        let obs_prec = n / self.obs_var;
234        let post_prec = prior_prec + obs_prec;
235        let post_var = 1.0 / post_prec;
236        let post_mu = (prior_prec * self.prior_mu + obs_prec * obs_mean) / post_prec;
237        Self {
238            prior_mu: post_mu,
239            prior_var: post_var,
240            obs_var: self.obs_var,
241        }
242    }
243
244    /// Posterior mean.
245    pub fn posterior_mean(&self) -> f64 {
246        self.prior_mu
247    }
248
249    /// Posterior variance.
250    pub fn posterior_variance(&self) -> f64 {
251        self.prior_var
252    }
253}
254
255// ── Dirichlet distribution ───────────────────────────────────────────────
256
257/// Dirichlet distribution, conjugate prior for multinomial likelihood.
258///
259/// After observing counts `c₁, ..., cₖ`, the posterior is
260/// `Dirichlet(α₁ + c₁, ..., αₖ + cₖ)`.
261#[derive(Debug, Clone)]
262pub struct Dirichlet {
263    alpha: Vec<f64>,
264}
265
266impl Dirichlet {
267    /// Create a Dirichlet distribution with concentration parameters `alpha`.
268    ///
269    /// All elements must be positive and the vector must have at least 2 elements.
270    pub fn new(alpha: Vec<f64>) -> Result<Self> {
271        if alpha.len() < 2 {
272            return Err(CyaneaError::InvalidInput(
273                "Dirichlet: need at least 2 categories".into(),
274            ));
275        }
276        if alpha.iter().any(|&a| a <= 0.0) {
277            return Err(CyaneaError::InvalidInput(
278                "Dirichlet: all alpha values must be positive".into(),
279            ));
280        }
281        Ok(Self { alpha })
282    }
283
284    /// Create a symmetric Dirichlet with `k` categories, each with
285    /// concentration `alpha`.
286    pub fn symmetric(k: usize, alpha: f64) -> Result<Self> {
287        if k < 2 {
288            return Err(CyaneaError::InvalidInput(
289                "Dirichlet: need at least 2 categories".into(),
290            ));
291        }
292        if alpha <= 0.0 {
293            return Err(CyaneaError::InvalidInput(
294                "Dirichlet: alpha must be positive".into(),
295            ));
296        }
297        Ok(Self {
298            alpha: vec![alpha; k],
299        })
300    }
301
302    /// Concentration parameters.
303    pub fn alpha(&self) -> &[f64] {
304        &self.alpha
305    }
306
307    /// Posterior after observing multinomial counts.
308    ///
309    /// # Panics
310    ///
311    /// Panics if `counts.len() != self.alpha.len()`.
312    pub fn update_multinomial(&self, counts: &[u64]) -> Self {
313        assert_eq!(
314            counts.len(),
315            self.alpha.len(),
316            "counts length must match alpha length"
317        );
318        Self {
319            alpha: self
320                .alpha
321                .iter()
322                .zip(counts.iter())
323                .map(|(&a, &c)| a + c as f64)
324                .collect(),
325        }
326    }
327
328    /// Expected value (mean) of the Dirichlet: `E[Xᵢ] = αᵢ / Σα`.
329    pub fn mean(&self) -> Vec<f64> {
330        let sum: f64 = self.alpha.iter().sum();
331        self.alpha.iter().map(|&a| a / sum).collect()
332    }
333
334    /// Variance of each component: `Var[Xᵢ] = αᵢ(Σα − αᵢ) / (Σα² (Σα + 1))`.
335    pub fn variance(&self) -> Vec<f64> {
336        let sum: f64 = self.alpha.iter().sum();
337        let denom = sum * sum * (sum + 1.0);
338        self.alpha.iter().map(|&a| a * (sum - a) / denom).collect()
339    }
340
341    /// Log-PDF of the Dirichlet at point `x`.
342    ///
343    /// # Errors
344    ///
345    /// Returns an error if `x.len() != alpha.len()` or if values don't
346    /// sum to approximately 1.
347    pub fn ln_pdf(&self, x: &[f64]) -> Result<f64> {
348        if x.len() != self.alpha.len() {
349            return Err(CyaneaError::InvalidInput(
350                "Dirichlet::ln_pdf: x length must match alpha length".into(),
351            ));
352        }
353        let sum: f64 = x.iter().sum();
354        if (sum - 1.0).abs() > 1e-6 {
355            return Err(CyaneaError::InvalidInput(
356                "Dirichlet::ln_pdf: x must sum to 1".into(),
357            ));
358        }
359
360        let alpha_sum: f64 = self.alpha.iter().sum();
361        let mut ln_b = -ln_gamma(alpha_sum);
362        for &a in &self.alpha {
363            ln_b += ln_gamma(a);
364        }
365
366        let mut result = -ln_b;
367        for (xi, &ai) in x.iter().zip(self.alpha.iter()) {
368            if *xi <= 0.0 {
369                return Err(CyaneaError::InvalidInput(
370                    "Dirichlet::ln_pdf: all x values must be positive".into(),
371                ));
372            }
373            result += (ai - 1.0) * xi.ln();
374        }
375
376        Ok(result)
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-6;

    /// `true` when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    // ── Beta tests ────────────────────────────────────────────────────

    #[test]
    fn beta_uniform_prior() {
        let uniform = Beta::new(1.0, 1.0).unwrap();
        assert!(close(uniform.mean(), 0.5, TOL));
    }

    #[test]
    fn beta_conjugacy() {
        // Beta(1,1) updated with 3 successes in 10 trials gives Beta(4, 8).
        let posterior = Beta::new(1.0, 1.0).unwrap().update_binomial(3, 10);
        assert!(close(posterior.alpha(), 4.0, TOL));
        assert!(close(posterior.beta(), 8.0, TOL));
        assert!(close(posterior.mean(), 4.0 / 12.0, TOL));
    }

    #[test]
    fn beta_pdf_at_mode() {
        // Beta(2, 5) has its mode at (2-1)/(2+5-2) = 0.2, so the density
        // there must exceed the density at points on either side.
        let dist = Beta::new(2.0, 5.0).unwrap();
        let peak = dist.pdf(0.2);
        assert!(peak > dist.pdf(0.1));
        assert!(peak > dist.pdf(0.5));
    }

    #[test]
    fn beta_cdf_boundaries() {
        let dist = Beta::new(2.0, 3.0).unwrap();
        assert_eq!(dist.cdf(0.0), 0.0);
        assert!(close(dist.cdf(1.0), 1.0, TOL));
    }

    #[test]
    fn beta_cdf_midpoint() {
        // Beta(1,1) is the uniform distribution, hence CDF(0.5) = 0.5.
        let uniform = Beta::new(1.0, 1.0).unwrap();
        assert!(close(uniform.cdf(0.5), 0.5, TOL));
    }

    #[test]
    fn beta_invalid() {
        assert!(Beta::new(0.0, 1.0).is_err());
        assert!(Beta::new(1.0, -1.0).is_err());
    }

    // ── Gamma tests ──────────────────────────────────────────────────

    #[test]
    fn gamma_mean_variance() {
        let dist = Gamma::new(3.0, 2.0).unwrap();
        assert!(close(dist.mean(), 1.5, TOL));
        assert!(close(dist.variance(), 0.75, TOL));
    }

    #[test]
    fn gamma_conjugacy_poisson() {
        // Gamma(2, 1) after observing a count of 5 becomes Gamma(7, 2).
        let posterior = Gamma::new(2.0, 1.0).unwrap().update_poisson(5);
        assert!(close(posterior.shape(), 7.0, TOL));
        assert!(close(posterior.rate(), 2.0, TOL));
    }

    #[test]
    fn gamma_conjugacy_batch() {
        // Gamma(2, 1) after observing [3, 5, 2] becomes Gamma(2+10, 1+3).
        let posterior = Gamma::new(2.0, 1.0)
            .unwrap()
            .update_poisson_batch(&[3, 5, 2]);
        assert!(close(posterior.shape(), 12.0, TOL));
        assert!(close(posterior.rate(), 4.0, TOL));
    }

    #[test]
    fn gamma_cdf() {
        // Gamma(1, 1) is Exponential(1), whose CDF is 1 - e^{-x}.
        let exponential = Gamma::new(1.0, 1.0).unwrap();
        let x: f64 = 2.0;
        let expected = 1.0 - (-x).exp();
        assert!(close(exponential.cdf(x), expected, 1e-8));
    }

    #[test]
    fn gamma_invalid() {
        assert!(Gamma::new(0.0, 1.0).is_err());
        assert!(Gamma::new(1.0, 0.0).is_err());
    }

    // ── NormalConjugate tests ────────────────────────────────────────

    #[test]
    fn normal_conjugate_single_update() {
        let posterior = NormalConjugate::new(0.0, 1.0, 1.0).unwrap().update(2.0);
        // mean = (0/1 + 2/1) / (1/1 + 1/1) = 1.0
        assert!(close(posterior.posterior_mean(), 1.0, TOL));
        // variance = 1 / (1 + 1) = 0.5
        assert!(close(posterior.posterior_variance(), 0.5, TOL));
    }

    #[test]
    fn normal_conjugate_batch_update() {
        let posterior = NormalConjugate::new(0.0, 1.0, 1.0)
            .unwrap()
            .update_batch(&[2.0, 4.0]);
        // Sample mean 3.0 with n = 2: prior precision 1, data precision 2,
        // so posterior precision 3, variance 1/3, mean (0 + 2*3)/3 = 2.0.
        assert!(close(posterior.posterior_mean(), 2.0, TOL));
        assert!(close(posterior.posterior_variance(), 1.0 / 3.0, TOL));
    }

    #[test]
    fn normal_conjugate_empty_batch() {
        let posterior = NormalConjugate::new(5.0, 2.0, 1.0)
            .unwrap()
            .update_batch(&[]);
        assert!(close(posterior.posterior_mean(), 5.0, TOL));
        assert!(close(posterior.posterior_variance(), 2.0, TOL));
    }

    #[test]
    fn normal_conjugate_precision_shrinkage() {
        // A very tight prior (precision 100) against a noisy observation
        // (precision 0.01) should leave the mean almost untouched.
        let posterior = NormalConjugate::new(0.0, 0.01, 100.0)
            .unwrap()
            .update(100.0);
        assert!(close(posterior.posterior_mean(), 0.0, 0.02));
    }

    #[test]
    fn normal_conjugate_invalid() {
        assert!(NormalConjugate::new(0.0, 0.0, 1.0).is_err());
        assert!(NormalConjugate::new(0.0, 1.0, 0.0).is_err());
    }

    // ── Dirichlet tests ──────────────────────────────────────────────

    #[test]
    fn dirichlet_symmetric_mean() {
        let means = Dirichlet::symmetric(4, 1.0).unwrap().mean();
        assert_eq!(means.len(), 4);
        assert!(means.iter().all(|&m| close(m, 0.25, TOL)));
    }

    #[test]
    fn dirichlet_conjugacy_multinomial() {
        let posterior = Dirichlet::symmetric(3, 1.0)
            .unwrap()
            .update_multinomial(&[10, 5, 15]);
        let expected = [11.0, 6.0, 16.0];
        let all_match = posterior
            .alpha()
            .iter()
            .zip(expected.iter())
            .all(|(&got, &want)| close(got, want, TOL));
        assert!(all_match);
    }

    #[test]
    fn dirichlet_ln_pdf() {
        // Dirichlet(1,1,1) is uniform on the simplex, so the log-density
        // is ln(Γ(3)/Γ(1)^3) = ln(2) everywhere.
        let uniform = Dirichlet::symmetric(3, 1.0).unwrap();
        let third = 1.0 / 3.0;
        let value = uniform.ln_pdf(&[third, third, third]).unwrap();
        assert!(close(value, 2.0_f64.ln(), 1e-6));
    }

    #[test]
    fn dirichlet_invalid() {
        // Too few categories.
        assert!(Dirichlet::new(vec![1.0]).is_err());
        assert!(Dirichlet::symmetric(1, 1.0).is_err());
        // Non-positive concentrations.
        assert!(Dirichlet::new(vec![1.0, -1.0]).is_err());
        assert!(Dirichlet::symmetric(3, 0.0).is_err());
    }

    #[test]
    fn dirichlet_ln_pdf_invalid() {
        let uniform = Dirichlet::symmetric(3, 1.0).unwrap();
        // Wrong number of components.
        assert!(uniform.ln_pdf(&[0.5, 0.5]).is_err());
        // Components do not sum to 1.
        assert!(uniform.ln_pdf(&[0.5, 0.3, 0.1]).is_err());
    }
}