1use crate::rv::data::CategoricalDatum;
2use crate::rv::dist::{Categorical, InvGamma, SymmetricDirichlet};
3use crate::rv::traits::*;
4use rand::Rng;
5use serde::{Deserialize, Serialize};
6
7use crate::mh::mh_prior;
8use crate::UpdatePrior;
9
10pub fn geweke(k: usize) -> SymmetricDirichlet {
12 SymmetricDirichlet::new_unchecked(1.0, k)
13}
14
15pub fn from_hyper(
17 k: usize,
18 hyper: CsdHyper,
19 mut rng: &mut impl Rng,
20) -> SymmetricDirichlet {
21 hyper.draw(k, &mut rng)
22}
23
24pub fn vague(k: usize) -> SymmetricDirichlet {
26 SymmetricDirichlet::new_unchecked(0.5, k)
27}
28
29impl<X: CategoricalDatum> UpdatePrior<X, Categorical, CsdHyper>
30 for SymmetricDirichlet
31{
32 fn update_prior<R: Rng>(
33 &mut self,
34 components: &[&Categorical],
35 hyper: &CsdHyper,
36 rng: &mut R,
37 ) -> f64 {
38 let mh_result = {
39 let k = self.k();
40 let kf = k as f64;
41
42 let loglike = |alpha: &f64| {
43 let sum_ln_gamma = special::Gamma::ln_gamma(*alpha).0 * kf;
45 let ln_gamma_sum = special::Gamma::ln_gamma(alpha * kf).0;
46 let am1 = alpha - 1.0;
47
48 components
49 .iter()
50 .map(|cpnt| {
51 let term = cpnt
52 .ln_weights()
53 .iter()
54 .map(|&ln_w| am1 * ln_w)
55 .sum::<f64>();
56 term - (sum_ln_gamma - ln_gamma_sum)
57 })
58 .sum::<f64>()
59 };
60
61 mh_prior(
62 self.alpha(),
63 loglike,
64 |rng| hyper.pr_alpha.draw(rng),
65 lace_consts::MH_PRIOR_ITERS,
66 rng,
67 )
68 };
69
70 self.set_alpha(mh_result.x).unwrap();
71 mh_result.score_x + hyper.pr_alpha.ln_f(&mh_result.x)
72 }
73}
74
75#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
76pub struct CsdHyper {
77 pub pr_alpha: InvGamma,
78}
79
80impl Default for CsdHyper {
81 fn default() -> Self {
82 CsdHyper {
83 pr_alpha: InvGamma::new(1.0, 1.0).unwrap(),
84 }
85 }
86}
87
88impl CsdHyper {
89 pub fn new(shape: f64, rate: f64) -> Self {
90 CsdHyper {
91 pr_alpha: InvGamma::new(shape, rate).unwrap(),
92 }
93 }
94
95 pub fn geweke() -> Self {
102 CsdHyper {
103 pr_alpha: InvGamma::new(30.0, 29.0).unwrap(),
104 }
105 }
106
107 pub fn vague(k: usize) -> Self {
109 CsdHyper {
110 pr_alpha: InvGamma::new(k as f64 + 1.0, 1.0).unwrap(),
111 }
112 }
113
114 pub fn draw(&self, k: usize, mut rng: &mut impl Rng) -> SymmetricDirichlet {
116 let alpha = self.pr_alpha.draw(&mut rng);
118 SymmetricDirichlet::new_unchecked(alpha, k)
119 }
120}