lace_stats/prior/
csd.rs

1use crate::rv::data::CategoricalDatum;
2use crate::rv::dist::{Categorical, InvGamma, SymmetricDirichlet};
3use crate::rv::traits::*;
4use rand::Rng;
5use serde::{Deserialize, Serialize};
6
7use crate::mh::mh_prior;
8use crate::UpdatePrior;
9
10/// Default `Csd` for Geweke testing
11pub fn geweke(k: usize) -> SymmetricDirichlet {
12    SymmetricDirichlet::new_unchecked(1.0, k)
13}
14
15/// Draw the prior from the hyper-prior
16pub fn from_hyper(
17    k: usize,
18    hyper: CsdHyper,
19    mut rng: &mut impl Rng,
20) -> SymmetricDirichlet {
21    hyper.draw(k, &mut rng)
22}
23
24/// Build a vague hyper-prior given `k` and draws the prior from that
25pub fn vague(k: usize) -> SymmetricDirichlet {
26    SymmetricDirichlet::new_unchecked(0.5, k)
27}
28
29impl<X: CategoricalDatum> UpdatePrior<X, Categorical, CsdHyper>
30    for SymmetricDirichlet
31{
32    fn update_prior<R: Rng>(
33        &mut self,
34        components: &[&Categorical],
35        hyper: &CsdHyper,
36        rng: &mut R,
37    ) -> f64 {
38        let mh_result = {
39            let k = self.k();
40            let kf = k as f64;
41
42            let loglike = |alpha: &f64| {
43                // Pre-compute costly gamma_ln functions
44                let sum_ln_gamma = special::Gamma::ln_gamma(*alpha).0 * kf;
45                let ln_gamma_sum = special::Gamma::ln_gamma(alpha * kf).0;
46                let am1 = alpha - 1.0;
47
48                components
49                    .iter()
50                    .map(|cpnt| {
51                        let term = cpnt
52                            .ln_weights()
53                            .iter()
54                            .map(|&ln_w| am1 * ln_w)
55                            .sum::<f64>();
56                        term - (sum_ln_gamma - ln_gamma_sum)
57                    })
58                    .sum::<f64>()
59            };
60
61            mh_prior(
62                self.alpha(),
63                loglike,
64                |rng| hyper.pr_alpha.draw(rng),
65                lace_consts::MH_PRIOR_ITERS,
66                rng,
67            )
68        };
69
70        self.set_alpha(mh_result.x).unwrap();
71        mh_result.score_x + hyper.pr_alpha.ln_f(&mh_result.x)
72    }
73}
74
75#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
76pub struct CsdHyper {
77    pub pr_alpha: InvGamma,
78}
79
80impl Default for CsdHyper {
81    fn default() -> Self {
82        CsdHyper {
83            pr_alpha: InvGamma::new(1.0, 1.0).unwrap(),
84        }
85    }
86}
87
88impl CsdHyper {
89    pub fn new(shape: f64, rate: f64) -> Self {
90        CsdHyper {
91            pr_alpha: InvGamma::new(shape, rate).unwrap(),
92        }
93    }
94
95    /// A restrictive prior to confine Geweke.
96    ///
97    /// Since the geweke test seeks to draw samples from the joint of the prior
98    /// and the data, p(x, θ), and since θ is indluenced by the hyper-prior, if
99    /// the hyper parameters are not tight, the data can go crazy and cause a
100    /// bunch of math errors.
101    pub fn geweke() -> Self {
102        CsdHyper {
103            pr_alpha: InvGamma::new(30.0, 29.0).unwrap(),
104        }
105    }
106
107    /// α ~ Gamma(k + 1, 1)
108    pub fn vague(k: usize) -> Self {
109        CsdHyper {
110            pr_alpha: InvGamma::new(k as f64 + 1.0, 1.0).unwrap(),
111        }
112    }
113
114    /// Draw a `Csd` from the hyper-prior
115    pub fn draw(&self, k: usize, mut rng: &mut impl Rng) -> SymmetricDirichlet {
116        // SymmetricDirichlet::new(self.pr_alpha.draw(&mut rng), k);
117        let alpha = self.pr_alpha.draw(&mut rng);
118        SymmetricDirichlet::new_unchecked(alpha, k)
119    }
120}