Skip to main content

scirs2_stats/bayesian/
conjugate.rs

1//! Conjugate prior distributions for Bayesian inference
2//!
3//! This module implements conjugate prior-posterior relationships for efficient Bayesian updates.
4
5use crate::error::{StatsError, StatsResult as Result};
6use scirs2_core::ndarray::{Array1, ArrayView1};
7use scirs2_core::validation::*;
8use statrs::statistics::Statistics;
9
/// Conjugate pairing of a Beta prior with a Binomial likelihood.
///
/// * Prior: `Beta(α, β)` over the success probability `p`
/// * Likelihood: `Binomial(n, p)`
/// * Posterior: `Beta(α + successes, β + failures)`
#[derive(Debug, Clone)]
pub struct BetaBinomial {
    /// Shape parameter `α` of the Beta distribution (observed successes are added to it).
    pub alpha: f64,
    /// Shape parameter `β` of the Beta distribution (observed failures are added to it).
    pub beta: f64,
}
22
23impl BetaBinomial {
24    /// Create a new Beta-Binomial conjugate prior
25    pub fn new(alpha: f64, beta: f64) -> Result<Self> {
26        check_positive(alpha, "alpha")?;
27        check_positive(beta, "beta")?;
28        Ok(Self { alpha, beta })
29    }
30
31    /// Update the prior with observed data
32    ///
33    /// # Arguments
34    /// * `successes` - Number of successes observed
35    /// * `failures` - Number of failures observed
36    ///
37    /// # Returns
38    /// Updated BetaBinomial with posterior parameters
39    pub fn update(&self, successes: usize, failures: usize) -> Self {
40        Self {
41            alpha: self.alpha + successes as f64,
42            beta: self.beta + failures as f64,
43        }
44    }
45
46    /// Compute the posterior mean
47    pub fn posterior_mean(&self) -> Result<f64> {
48        let total = self.alpha + self.beta;
49        if total.abs() < f64::EPSILON {
50            return Err(StatsError::domain(
51                "Cannot compute posterior mean: alpha + beta too close to zero",
52            ));
53        }
54        Ok(self.alpha / total)
55    }
56
57    /// Compute the posterior variance
58    pub fn posterior_variance(&self) -> Result<f64> {
59        let total = self.alpha + self.beta;
60        if total.abs() < f64::EPSILON {
61            return Err(StatsError::domain(
62                "Cannot compute posterior variance: alpha + beta too close to zero",
63            ));
64        }
65        let denominator = total * total * (total + 1.0);
66        if denominator.abs() < f64::EPSILON {
67            return Err(StatsError::domain(
68                "Cannot compute posterior variance: denominator too close to zero",
69            ));
70        }
71        Ok((self.alpha * self.beta) / denominator)
72    }
73
74    /// Compute the posterior mode (MAP estimate)
75    pub fn posterior_mode(&self) -> Result<Option<f64>> {
76        if self.alpha > 1.0 && self.beta > 1.0 {
77            let denominator = self.alpha + self.beta - 2.0;
78            if denominator.abs() < f64::EPSILON {
79                return Err(StatsError::domain(
80                    "Cannot compute posterior mode: alpha + beta - 2 too close to zero",
81                ));
82            }
83            Ok(Some((self.alpha - 1.0) / denominator))
84        } else {
85            Ok(None)
86        }
87    }
88
89    /// Compute credible interval
90    pub fn credible_interval(&self, confidence: f64) -> Result<(f64, f64)> {
91        check_probability(confidence, "confidence")?;
92
93        // Use beta distribution quantiles
94        use crate::distributions::beta::Beta;
95        let dist = Beta::new(self.alpha, self.beta, 0.0, 1.0)?;
96
97        let alpha_level = (1.0 - confidence) / 2.0;
98        Ok((dist.ppf(alpha_level)?, dist.ppf(1.0 - alpha_level)?))
99    }
100}
101
/// Conjugate pairing of a Gamma prior with a Poisson likelihood
/// (also referred to as the Gamma–Negative-Binomial pair).
///
/// The Gamma prior is expressed in shape/rate form:
///
/// * Prior: `Gamma(α, β)` over the Poisson rate `λ`
/// * Likelihood: `Poisson(λ)` for each observed count
/// * Posterior: `Gamma(α + Σ data, β + n)`, with `n` the number of observations
#[derive(Debug, Clone)]
pub struct GammaPoisson {
    /// Shape `α` of the Gamma distribution.
    pub alpha: f64,
    /// Rate `β` of the Gamma distribution.
    pub beta: f64,
}
114
115impl GammaPoisson {
116    /// Create a new Gamma-Poisson conjugate prior
117    pub fn new(alpha: f64, beta: f64) -> Result<Self> {
118        check_positive(alpha, "alpha")?;
119        check_positive(beta, "beta")?;
120        Ok(Self { alpha, beta })
121    }
122
123    /// Update the prior with observed count data
124    ///
125    /// # Arguments
126    /// * `data` - Array of observed counts
127    ///
128    /// # Returns
129    /// Updated GammaPoisson with posterior parameters
130    pub fn update(&self, data: ArrayView1<f64>) -> Result<Self> {
131        checkarray_finite(&data, "data")?;
132        let sum: f64 = data.sum();
133        let n = data.len() as f64;
134
135        Ok(Self {
136            alpha: self.alpha + sum,
137            beta: self.beta + n,
138        })
139    }
140
141    /// Compute the posterior mean
142    pub fn posterior_mean(&self) -> Result<f64> {
143        if self.beta.abs() < f64::EPSILON {
144            return Err(StatsError::domain(
145                "Cannot compute posterior mean: beta too close to zero",
146            ));
147        }
148        Ok(self.alpha / self.beta)
149    }
150
151    /// Compute the posterior variance
152    pub fn posterior_variance(&self) -> Result<f64> {
153        if self.beta.abs() < f64::EPSILON {
154            return Err(StatsError::domain(
155                "Cannot compute posterior variance: beta too close to zero",
156            ));
157        }
158        Ok(self.alpha / (self.beta * self.beta))
159    }
160
161    /// Compute the posterior mode (MAP estimate)
162    pub fn posterior_mode(&self) -> Result<Option<f64>> {
163        if self.alpha >= 1.0 {
164            if self.beta.abs() < f64::EPSILON {
165                return Err(StatsError::domain(
166                    "Cannot compute posterior mode: beta too close to zero",
167                ));
168            }
169            Ok(Some((self.alpha - 1.0) / self.beta))
170        } else {
171            Ok(None)
172        }
173    }
174
175    /// Compute credible interval
176    pub fn credible_interval(&self, confidence: f64) -> Result<(f64, f64)> {
177        check_probability(confidence, "confidence")?;
178
179        // Use gamma distribution quantiles
180        use crate::distributions::gamma::Gamma;
181        let dist = Gamma::new(self.alpha, 1.0 / self.beta, 0.0)?; // Note: using scale parameterization
182
183        let alpha_level = (1.0 - confidence) / 2.0;
184        Ok((dist.ppf(alpha_level)?, dist.ppf(1.0 - alpha_level)?))
185    }
186}
187
/// Conjugate Normal prior on the mean of Normal data whose variance is known.
///
/// * Prior: `Normal(μ₀, σ₀²)` over the unknown mean `μ`
/// * Likelihood: `Normal(μ, σ²)` with `σ²` fixed and known
/// * Posterior: `Normal(μₙ, σₙ²)`
///
/// `update` returns a value whose `prior_mean`/`prior_variance` hold the
/// posterior parameters, so updates can be chained.
#[derive(Debug, Clone)]
pub struct NormalKnownVariance {
    /// Mean `μ₀` of the Normal distribution over `μ`.
    pub prior_mean: f64,
    /// Variance `σ₀²` of the Normal distribution over `μ`.
    pub prior_variance: f64,
    /// Known variance `σ²` of each individual observation.
    pub data_variance: f64,
}
202
203impl NormalKnownVariance {
204    /// Create a new Normal conjugate prior with known data variance
205    pub fn new(prior_mean: f64, prior_variance: f64, data_variance: f64) -> Result<Self> {
206        check_positive(prior_variance, "prior_variance")?;
207        check_positive(data_variance, "data_variance")?;
208        Ok(Self {
209            prior_mean,
210            prior_variance,
211            data_variance,
212        })
213    }
214
215    /// Update the prior with observed data
216    ///
217    /// # Arguments
218    /// * `data` - Array of observed values
219    ///
220    /// # Returns
221    /// Updated NormalKnownVariance with posterior parameters
222    pub fn update(&self, data: ArrayView1<f64>) -> Result<Self> {
223        checkarray_finite(&data, "data")?;
224        let n = data.len() as f64;
225        let data_mean = data.mean();
226
227        if self.prior_variance.abs() < f64::EPSILON {
228            return Err(StatsError::domain(
229                "Cannot update: prior_variance too close to zero",
230            ));
231        }
232        if self.data_variance.abs() < f64::EPSILON {
233            return Err(StatsError::domain(
234                "Cannot update: data_variance too close to zero",
235            ));
236        }
237
238        let precision_prior = 1.0 / self.prior_variance;
239        let precisiondata = n / self.data_variance;
240        let precision_posterior = precision_prior + precisiondata;
241
242        if precision_posterior.abs() < f64::EPSILON {
243            return Err(StatsError::domain(
244                "Cannot update: precision_posterior too close to zero",
245            ));
246        }
247
248        let posterior_variance = 1.0 / precision_posterior;
249        let posterior_mean =
250            (precision_prior * self.prior_mean + precisiondata * data_mean) / precision_posterior;
251
252        Ok(Self {
253            prior_mean: posterior_mean,
254            prior_variance: posterior_variance,
255            data_variance: self.data_variance,
256        })
257    }
258
259    /// Compute the posterior mean
260    pub fn posterior_mean(&self) -> f64 {
261        self.prior_mean
262    }
263
264    /// Compute the posterior variance
265    pub fn posterior_variance(&self) -> f64 {
266        self.prior_variance
267    }
268
269    /// Compute credible interval
270    pub fn credible_interval(&self, confidence: f64) -> Result<(f64, f64)> {
271        check_probability(confidence, "confidence")?;
272
273        // Use normal distribution quantiles
274        use crate::distributions::normal::Normal;
275        if self.prior_variance < 0.0 {
276            return Err(StatsError::domain(
277                "Cannot compute credible interval: prior_variance must be non-negative",
278            ));
279        }
280        let dist = Normal::new(self.prior_mean, self.prior_variance.sqrt())?;
281
282        let alpha_level = (1.0 - confidence) / 2.0;
283        Ok((dist.ppf(alpha_level)?, dist.ppf(1.0 - alpha_level)?))
284    }
285
286    /// Compute the predictive distribution parameters
287    pub fn predictive_params(&self) -> (f64, f64) {
288        (self.prior_mean, self.prior_variance + self.data_variance)
289    }
290}
291
292/// Dirichlet-Multinomial conjugate pair
293///
294/// Prior: Dirichlet(α)
295/// Likelihood: Multinomial(n, p)
296/// Posterior: Dirichlet(α + counts)
297#[derive(Debug, Clone)]
298pub struct DirichletMultinomial {
299    /// Concentration parameters of the Dirichlet prior
300    pub alpha: Array1<f64>,
301}
302
303impl DirichletMultinomial {
304    /// Create a new Dirichlet-Multinomial conjugate prior
305    pub fn new(alpha: Array1<f64>) -> Result<Self> {
306        checkarray_finite(&alpha, "alpha")?;
307        for &a in alpha.iter() {
308            check_positive(a, "_alpha element")?;
309        }
310        Ok(Self { alpha })
311    }
312
313    /// Create uniform prior with given dimension
314    pub fn uniform(k: usize) -> Result<Self> {
315        check_positive(k, "k")?;
316        Ok(Self {
317            alpha: Array1::from_elem(k, 1.0),
318        })
319    }
320
321    /// Update the prior with observed count data
322    ///
323    /// # Arguments
324    /// * `counts` - Array of observed counts for each category
325    ///
326    /// # Returns
327    /// Updated DirichletMultinomial with posterior parameters
328    pub fn update(&self, counts: ArrayView1<f64>) -> Result<Self> {
329        if counts.len() != self.alpha.len() {
330            return Err(StatsError::DimensionMismatch(format!(
331                "counts length ({}) must match alpha length ({})",
332                counts.len(),
333                self.alpha.len()
334            )));
335        }
336        checkarray_finite(&counts, "counts")?;
337
338        Ok(Self {
339            alpha: &self.alpha + &counts,
340        })
341    }
342
343    /// Compute the posterior mean
344    pub fn posterior_mean(&self) -> Result<Array1<f64>> {
345        let sum = self.alpha.sum();
346        if sum.abs() < f64::EPSILON {
347            return Err(StatsError::domain(
348                "Cannot compute posterior mean: sum of alpha parameters too close to zero",
349            ));
350        }
351        Ok(&self.alpha / sum)
352    }
353
354    /// Compute the posterior mode (MAP estimate)
355    pub fn posterior_mode(&self) -> Result<Option<Array1<f64>>> {
356        let k = self.alpha.len() as f64;
357        if self.alpha.iter().all(|&a| a > 1.0) {
358            let sum = self.alpha.sum();
359            let denominator = sum - k;
360            if denominator.abs() < f64::EPSILON {
361                return Err(StatsError::domain(
362                    "Cannot compute posterior mode: sum - k too close to zero",
363                ));
364            }
365            Ok(Some((&self.alpha - 1.0) / denominator))
366        } else {
367            Ok(None)
368        }
369    }
370
371    /// Compute the marginal variance for each component
372    pub fn posterior_variance(&self) -> Result<Array1<f64>> {
373        let sum = self.alpha.sum();
374        let denominator = sum + 1.0;
375        if denominator.abs() < f64::EPSILON {
376            return Err(StatsError::domain(
377                "Cannot compute posterior variance: sum + 1 too close to zero",
378            ));
379        }
380        let mean = self.posterior_mean()?;
381        Ok(mean.mapv(|p| p * (1.0 - p) / denominator))
382    }
383}
384
/// Normal-Inverse-Gamma prior for Normal data whose mean and variance are both unknown.
///
/// * Prior: `NIG(μ₀, λ, α, β)` jointly over `(μ, σ²)`
/// * Likelihood: `Normal(μ, σ²)` with both parameters unknown
/// * Posterior: `NIG(μₙ, λₙ, αₙ, βₙ)`
#[derive(Debug, Clone)]
pub struct NormalInverseGamma {
    /// Location `μ₀` of the prior on the mean.
    pub mu0: f64,
    /// Precision scaling factor `λ` of the prior on the mean.
    pub lambda: f64,
    /// Shape `α` of the Inverse-Gamma prior on the variance.
    pub alpha: f64,
    /// Scale `β` of the Inverse-Gamma prior on the variance.
    pub beta: f64,
}
401
402impl NormalInverseGamma {
403    /// Create a new Normal-Inverse-Gamma conjugate prior
404    pub fn new(mu0: f64, lambda: f64, alpha: f64, beta: f64) -> Result<Self> {
405        check_positive(lambda, "lambda")?;
406        check_positive(alpha, "alpha")?;
407        check_positive(beta, "beta")?;
408        Ok(Self {
409            mu0,
410            lambda,
411            alpha,
412            beta,
413        })
414    }
415
416    /// Update the prior with observed data
417    pub fn update(&self, data: ArrayView1<f64>) -> Result<Self> {
418        checkarray_finite(&data, "data")?;
419        let n = data.len() as f64;
420        let data_mean = data.mean();
421
422        // Compute sum of squares
423        let ss = data.mapv(|x| (x - data_mean).powi(2)).sum();
424
425        // Update parameters
426        let lambda_n = self.lambda + n;
427        if lambda_n.abs() < f64::EPSILON {
428            return Err(StatsError::domain(
429                "Cannot update: lambda_n too close to zero",
430            ));
431        }
432        let mu_n = (self.lambda * self.mu0 + n * data_mean) / lambda_n;
433        let alpha_n = self.alpha + n / 2.0;
434        let beta_n = self.beta
435            + 0.5 * ss
436            + 0.5 * self.lambda * n * (data_mean - self.mu0).powi(2) / lambda_n;
437
438        Ok(Self {
439            mu0: mu_n,
440            lambda: lambda_n,
441            alpha: alpha_n,
442            beta: beta_n,
443        })
444    }
445
446    /// Compute the posterior mean of μ
447    pub fn posterior_mean_mu(&self) -> f64 {
448        self.mu0
449    }
450
451    /// Compute the posterior mean of σ²
452    pub fn posterior_mean_variance(&self) -> Result<f64> {
453        let denominator = self.alpha - 1.0;
454        if denominator.abs() < f64::EPSILON {
455            return Err(StatsError::domain(
456                "Cannot compute posterior mean variance: alpha - 1 too close to zero",
457            ));
458        }
459        Ok(self.beta / denominator)
460    }
461
462    /// Compute the marginal posterior variance of μ
463    pub fn posterior_variance_mu(&self) -> Result<f64> {
464        let denominator = self.lambda * (self.alpha - 1.0);
465        if denominator.abs() < f64::EPSILON {
466            return Err(StatsError::domain(
467                "Cannot compute posterior variance mu: lambda * (alpha - 1) too close to zero",
468            ));
469        }
470        Ok(self.beta / denominator)
471    }
472}