1use crate::distribution::{betai, gammainc, ln_gamma, Distribution};
12use cyanea_core::{CyaneaError, Result};
13
14#[derive(Debug, Clone, Copy)]
21pub struct Beta {
22 alpha: f64,
23 beta: f64,
24}
25
26impl Beta {
27 pub fn new(alpha: f64, beta: f64) -> Result<Self> {
31 if alpha <= 0.0 || beta <= 0.0 {
32 return Err(CyaneaError::InvalidInput(
33 "Beta: alpha and beta must be positive".into(),
34 ));
35 }
36 Ok(Self { alpha, beta })
37 }
38
39 pub fn alpha(&self) -> f64 {
41 self.alpha
42 }
43
44 pub fn beta(&self) -> f64 {
46 self.beta
47 }
48
49 pub fn update_binomial(&self, successes: u64, trials: u64) -> Self {
51 Self {
52 alpha: self.alpha + successes as f64,
53 beta: self.beta + (trials - successes) as f64,
54 }
55 }
56}
57
58impl Distribution for Beta {
59 fn pdf(&self, x: f64) -> f64 {
60 if x <= 0.0 || x >= 1.0 {
61 return 0.0;
62 }
63 let ln_beta_fn = ln_gamma(self.alpha) + ln_gamma(self.beta)
64 - ln_gamma(self.alpha + self.beta);
65 let ln_pdf = (self.alpha - 1.0) * x.ln()
66 + (self.beta - 1.0) * (1.0 - x).ln()
67 - ln_beta_fn;
68 ln_pdf.exp()
69 }
70
71 fn cdf(&self, x: f64) -> f64 {
72 if x <= 0.0 {
73 return 0.0;
74 }
75 if x >= 1.0 {
76 return 1.0;
77 }
78 betai(self.alpha, self.beta, x).unwrap_or(0.0)
79 }
80
81 fn mean(&self) -> f64 {
82 self.alpha / (self.alpha + self.beta)
83 }
84
85 fn variance(&self) -> f64 {
86 let ab = self.alpha + self.beta;
87 (self.alpha * self.beta) / (ab * ab * (ab + 1.0))
88 }
89}
90
91#[derive(Debug, Clone, Copy)]
99pub struct Gamma {
100 shape: f64,
101 rate: f64,
102}
103
104impl Gamma {
105 pub fn new(shape: f64, rate: f64) -> Result<Self> {
109 if shape <= 0.0 || rate <= 0.0 {
110 return Err(CyaneaError::InvalidInput(
111 "Gamma: shape and rate must be positive".into(),
112 ));
113 }
114 Ok(Self { shape, rate })
115 }
116
117 pub fn shape(&self) -> f64 {
119 self.shape
120 }
121
122 pub fn rate(&self) -> f64 {
124 self.rate
125 }
126
127 pub fn update_poisson(&self, count: u64) -> Self {
129 Self {
130 shape: self.shape + count as f64,
131 rate: self.rate + 1.0,
132 }
133 }
134
135 pub fn update_poisson_batch(&self, counts: &[u64]) -> Self {
137 let total: u64 = counts.iter().sum();
138 Self {
139 shape: self.shape + total as f64,
140 rate: self.rate + counts.len() as f64,
141 }
142 }
143}
144
145impl Distribution for Gamma {
146 fn pdf(&self, x: f64) -> f64 {
147 if x <= 0.0 {
148 return 0.0;
149 }
150 let ln_pdf = self.shape * self.rate.ln() - ln_gamma(self.shape)
151 + (self.shape - 1.0) * x.ln()
152 - self.rate * x;
153 ln_pdf.exp()
154 }
155
156 fn cdf(&self, x: f64) -> f64 {
157 if x <= 0.0 {
158 return 0.0;
159 }
160 gammainc(self.shape, self.rate * x).unwrap_or(0.0)
163 }
164
165 fn mean(&self) -> f64 {
166 self.shape / self.rate
167 }
168
169 fn variance(&self) -> f64 {
170 self.shape / (self.rate * self.rate)
171 }
172}
173
174#[derive(Debug, Clone, Copy)]
181pub struct NormalConjugate {
182 prior_mu: f64,
183 prior_var: f64,
184 obs_var: f64,
185}
186
187impl NormalConjugate {
188 pub fn new(prior_mu: f64, prior_var: f64, obs_var: f64) -> Result<Self> {
194 if prior_var <= 0.0 {
195 return Err(CyaneaError::InvalidInput(
196 "NormalConjugate: prior_var must be positive".into(),
197 ));
198 }
199 if obs_var <= 0.0 {
200 return Err(CyaneaError::InvalidInput(
201 "NormalConjugate: obs_var must be positive".into(),
202 ));
203 }
204 Ok(Self {
205 prior_mu,
206 prior_var,
207 obs_var,
208 })
209 }
210
211 pub fn update(&self, observation: f64) -> Self {
213 let prior_prec = 1.0 / self.prior_var;
214 let obs_prec = 1.0 / self.obs_var;
215 let post_prec = prior_prec + obs_prec;
216 let post_var = 1.0 / post_prec;
217 let post_mu = (prior_prec * self.prior_mu + obs_prec * observation) / post_prec;
218 Self {
219 prior_mu: post_mu,
220 prior_var: post_var,
221 obs_var: self.obs_var,
222 }
223 }
224
225 pub fn update_batch(&self, observations: &[f64]) -> Self {
227 let n = observations.len() as f64;
228 if n == 0.0 {
229 return *self;
230 }
231 let obs_mean: f64 = observations.iter().sum::<f64>() / n;
232 let prior_prec = 1.0 / self.prior_var;
233 let obs_prec = n / self.obs_var;
234 let post_prec = prior_prec + obs_prec;
235 let post_var = 1.0 / post_prec;
236 let post_mu = (prior_prec * self.prior_mu + obs_prec * obs_mean) / post_prec;
237 Self {
238 prior_mu: post_mu,
239 prior_var: post_var,
240 obs_var: self.obs_var,
241 }
242 }
243
244 pub fn posterior_mean(&self) -> f64 {
246 self.prior_mu
247 }
248
249 pub fn posterior_variance(&self) -> f64 {
251 self.prior_var
252 }
253}
254
255#[derive(Debug, Clone)]
262pub struct Dirichlet {
263 alpha: Vec<f64>,
264}
265
266impl Dirichlet {
267 pub fn new(alpha: Vec<f64>) -> Result<Self> {
271 if alpha.len() < 2 {
272 return Err(CyaneaError::InvalidInput(
273 "Dirichlet: need at least 2 categories".into(),
274 ));
275 }
276 if alpha.iter().any(|&a| a <= 0.0) {
277 return Err(CyaneaError::InvalidInput(
278 "Dirichlet: all alpha values must be positive".into(),
279 ));
280 }
281 Ok(Self { alpha })
282 }
283
284 pub fn symmetric(k: usize, alpha: f64) -> Result<Self> {
287 if k < 2 {
288 return Err(CyaneaError::InvalidInput(
289 "Dirichlet: need at least 2 categories".into(),
290 ));
291 }
292 if alpha <= 0.0 {
293 return Err(CyaneaError::InvalidInput(
294 "Dirichlet: alpha must be positive".into(),
295 ));
296 }
297 Ok(Self {
298 alpha: vec![alpha; k],
299 })
300 }
301
302 pub fn alpha(&self) -> &[f64] {
304 &self.alpha
305 }
306
307 pub fn update_multinomial(&self, counts: &[u64]) -> Self {
313 assert_eq!(
314 counts.len(),
315 self.alpha.len(),
316 "counts length must match alpha length"
317 );
318 Self {
319 alpha: self
320 .alpha
321 .iter()
322 .zip(counts.iter())
323 .map(|(&a, &c)| a + c as f64)
324 .collect(),
325 }
326 }
327
328 pub fn mean(&self) -> Vec<f64> {
330 let sum: f64 = self.alpha.iter().sum();
331 self.alpha.iter().map(|&a| a / sum).collect()
332 }
333
334 pub fn variance(&self) -> Vec<f64> {
336 let sum: f64 = self.alpha.iter().sum();
337 let denom = sum * sum * (sum + 1.0);
338 self.alpha.iter().map(|&a| a * (sum - a) / denom).collect()
339 }
340
341 pub fn ln_pdf(&self, x: &[f64]) -> Result<f64> {
348 if x.len() != self.alpha.len() {
349 return Err(CyaneaError::InvalidInput(
350 "Dirichlet::ln_pdf: x length must match alpha length".into(),
351 ));
352 }
353 let sum: f64 = x.iter().sum();
354 if (sum - 1.0).abs() > 1e-6 {
355 return Err(CyaneaError::InvalidInput(
356 "Dirichlet::ln_pdf: x must sum to 1".into(),
357 ));
358 }
359
360 let alpha_sum: f64 = self.alpha.iter().sum();
361 let mut ln_b = -ln_gamma(alpha_sum);
362 for &a in &self.alpha {
363 ln_b += ln_gamma(a);
364 }
365
366 let mut result = -ln_b;
367 for (xi, &ai) in x.iter().zip(self.alpha.iter()) {
368 if *xi <= 0.0 {
369 return Err(CyaneaError::InvalidInput(
370 "Dirichlet::ln_pdf: all x values must be positive".into(),
371 ));
372 }
373 result += (ai - 1.0) * xi.ln();
374 }
375
376 Ok(result)
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    const TOL: f64 = 1e-6;

    /// `true` when `actual` lies strictly within `tol` of `expected`.
    fn within(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// `within` at the module's default tolerance `TOL`.
    fn close(actual: f64, expected: f64) -> bool {
        within(actual, expected, TOL)
    }

    // ---------------------------------------------------------------- Beta

    #[test]
    fn beta_uniform_prior() {
        let uniform = Beta::new(1.0, 1.0).unwrap();
        assert!(close(uniform.mean(), 0.5));
    }

    #[test]
    fn beta_conjugacy() {
        // 3 successes in 10 trials on a uniform prior gives Beta(4, 8).
        let posterior = Beta::new(1.0, 1.0).unwrap().update_binomial(3, 10);
        assert!(close(posterior.alpha(), 4.0));
        assert!(close(posterior.beta(), 8.0));
        assert!(close(posterior.mean(), 4.0 / 12.0));
    }

    #[test]
    fn beta_pdf_at_mode() {
        // The mode of Beta(2, 5) is (2 - 1) / (2 + 5 - 2) = 0.2.
        let dist = Beta::new(2.0, 5.0).unwrap();
        let peak = dist.pdf(0.2);
        assert!(peak > dist.pdf(0.1));
        assert!(peak > dist.pdf(0.5));
    }

    #[test]
    fn beta_cdf_boundaries() {
        let dist = Beta::new(2.0, 3.0).unwrap();
        assert_eq!(dist.cdf(0.0), 0.0);
        assert!(close(dist.cdf(1.0), 1.0));
    }

    #[test]
    fn beta_cdf_midpoint() {
        // Beta(1, 1) is uniform, so half the mass lies below 0.5.
        let uniform = Beta::new(1.0, 1.0).unwrap();
        assert!(close(uniform.cdf(0.5), 0.5));
    }

    #[test]
    fn beta_invalid() {
        assert!(Beta::new(0.0, 1.0).is_err());
        assert!(Beta::new(1.0, -1.0).is_err());
    }

    // --------------------------------------------------------------- Gamma

    #[test]
    fn gamma_mean_variance() {
        let dist = Gamma::new(3.0, 2.0).unwrap();
        assert!(close(dist.mean(), 1.5));
        assert!(close(dist.variance(), 0.75));
    }

    #[test]
    fn gamma_conjugacy_poisson() {
        let posterior = Gamma::new(2.0, 1.0).unwrap().update_poisson(5);
        assert!(close(posterior.shape(), 7.0));
        assert!(close(posterior.rate(), 2.0));
    }

    #[test]
    fn gamma_conjugacy_batch() {
        // Three observations summing to 10: shape + 10, rate + 3.
        let posterior = Gamma::new(2.0, 1.0).unwrap().update_poisson_batch(&[3, 5, 2]);
        assert!(close(posterior.shape(), 12.0));
        assert!(close(posterior.rate(), 4.0));
    }

    #[test]
    fn gamma_cdf() {
        // Gamma(1, 1) is Exp(1), whose CDF is 1 - e^(-x).
        let exponential = Gamma::new(1.0, 1.0).unwrap();
        let x: f64 = 2.0;
        assert!(within(exponential.cdf(x), 1.0 - (-x).exp(), 1e-8));
    }

    #[test]
    fn gamma_invalid() {
        assert!(Gamma::new(0.0, 1.0).is_err());
        assert!(Gamma::new(1.0, 0.0).is_err());
    }

    // ----------------------------------------------------- NormalConjugate

    #[test]
    fn normal_conjugate_single_update() {
        // Equal prior and observation precision: posterior mean is the midpoint
        // of 0 and 2, and the precisions add (1 + 1 = 2, variance 0.5).
        let posterior = NormalConjugate::new(0.0, 1.0, 1.0).unwrap().update(2.0);
        assert!(close(posterior.posterior_mean(), 1.0));
        assert!(close(posterior.posterior_variance(), 0.5));
    }

    #[test]
    fn normal_conjugate_batch_update() {
        // Two observations with mean 3: precision 1 + 2 = 3, mean (0 + 2*3)/3 = 2.
        let posterior = NormalConjugate::new(0.0, 1.0, 1.0)
            .unwrap()
            .update_batch(&[2.0, 4.0]);
        assert!(close(posterior.posterior_mean(), 2.0));
        assert!(close(posterior.posterior_variance(), 1.0 / 3.0));
    }

    #[test]
    fn normal_conjugate_empty_batch() {
        let posterior = NormalConjugate::new(5.0, 2.0, 1.0).unwrap().update_batch(&[]);
        assert!(close(posterior.posterior_mean(), 5.0));
        assert!(close(posterior.posterior_variance(), 2.0));
    }

    #[test]
    fn normal_conjugate_precision_shrinkage() {
        // A tight prior and a very noisy observation: the mean barely moves.
        let posterior = NormalConjugate::new(0.0, 0.01, 100.0).unwrap().update(100.0);
        assert!(posterior.posterior_mean().abs() < 0.02);
    }

    #[test]
    fn normal_conjugate_invalid() {
        assert!(NormalConjugate::new(0.0, 0.0, 1.0).is_err());
        assert!(NormalConjugate::new(0.0, 1.0, 0.0).is_err());
    }

    // ----------------------------------------------------------- Dirichlet

    #[test]
    fn dirichlet_symmetric_mean() {
        let means = Dirichlet::symmetric(4, 1.0).unwrap().mean();
        assert_eq!(means.len(), 4);
        assert!(means.iter().all(|&m| close(m, 0.25)));
    }

    #[test]
    fn dirichlet_conjugacy_multinomial() {
        let posterior = Dirichlet::symmetric(3, 1.0)
            .unwrap()
            .update_multinomial(&[10, 5, 15]);
        for (&got, &want) in posterior.alpha().iter().zip(&[11.0, 6.0, 16.0]) {
            assert!(close(got, want));
        }
    }

    #[test]
    fn dirichlet_ln_pdf() {
        // Dirichlet(1, 1, 1) is uniform on the 2-simplex with density Γ(3) = 2.
        let flat = Dirichlet::symmetric(3, 1.0).unwrap();
        let third = 1.0 / 3.0;
        let ln_density = flat.ln_pdf(&[third, third, third]).unwrap();
        assert!(within(ln_density, 2.0_f64.ln(), 1e-6));
    }

    #[test]
    fn dirichlet_invalid() {
        assert!(Dirichlet::new(vec![1.0]).is_err());
        assert!(Dirichlet::new(vec![1.0, -1.0]).is_err());
        assert!(Dirichlet::symmetric(1, 1.0).is_err());
        assert!(Dirichlet::symmetric(3, 0.0).is_err());
    }

    #[test]
    fn dirichlet_ln_pdf_invalid() {
        let flat = Dirichlet::symmetric(3, 1.0).unwrap();
        // Wrong length.
        assert!(flat.ln_pdf(&[0.5, 0.5]).is_err());
        // Does not sum to 1.
        assert!(flat.ln_pdf(&[0.5, 0.3, 0.1]).is_err());
    }
}