// lace/cc/traits.rs

1use std::fmt::Debug;
2
3use rv::data::BernoulliSuffStat;
4use rv::data::CategoricalDatum;
5use rv::data::CategoricalSuffStat;
6use rv::data::GaussianSuffStat;
7use rv::data::PoissonSuffStat;
8use rv::dist::Bernoulli;
9use rv::dist::Beta;
10use rv::dist::Categorical;
11use rv::dist::Gamma;
12use rv::dist::Gaussian;
13use rv::dist::NormalInvChiSquared;
14use rv::dist::Poisson;
15use rv::dist::SymmetricDirichlet;
16use rv::traits::ConjugatePrior;
17use rv::traits::HasDensity;
18use rv::traits::HasSuffStat;
19use rv::traits::Mode;
20use rv::traits::Rv;
21use rv::traits::Sampleable;
22use serde::de::DeserializeOwned;
23use serde::Serialize;
24
25use crate::cc::feature::Component;
26use crate::cc::feature::FType;
27use crate::data::SparseContainer;
28use crate::data::TranslateContainer;
29use crate::data::TranslateDatum;
30use crate::stats::prior::csd::CsdHyper;
31use crate::stats::prior::nix::NixHyper;
32use crate::stats::prior::pg::PgHyper;
33use crate::stats::UpdatePrior;
34
35/// Score accumulation for `finite_cpu` and `slice` row transition kernels.
36///
37/// Provides two functions to add the scores (log likelihood) of a vector of
38/// data to a vector of existing scores.
39pub trait AccumScore<X: Clone + Default>: Rv<X> + Sync {
40    // XXX: Default implementations can be improved upon by pre-computing
41    // normalizers
42    fn accum_score(&self, scores: &mut [f64], container: &SparseContainer<X>) {
43        use crate::data::AccumScore;
44        container.accum_score(scores, &|x| self.ln_f(x))
45    }
46}
47
48impl<X: CategoricalDatum + Default> AccumScore<X> for Categorical {}
49impl AccumScore<u32> for Poisson {}
50impl AccumScore<f64> for Gaussian {}
51impl AccumScore<bool> for Bernoulli {}
52
53pub trait HasFType {
54    fn ftype() -> FType;
55}
56
// Implement `HasFType` for the type `$fx_ty` so that `ftype()` returns
// `FType::$variant`. Paths go through `$crate` so the expansion resolves
// regardless of the invoking module's imports.
macro_rules! impl_ftype {
    ($fx_ty:ty, $variant:ident) => {
        impl $crate::cc::traits::HasFType for $fx_ty {
            fn ftype() -> $crate::cc::feature::FType {
                $crate::cc::feature::FType::$variant
            }
        }
    };
}
66
67impl_ftype!(Poisson, Count);
68impl_ftype!(Gaussian, Continuous);
69impl_ftype!(Bernoulli, Binary);
70impl_ftype!(Categorical, Categorical);
71
72/// A Lace-ready datum.
73pub trait LaceDatum:
74    Sync + Serialize + DeserializeOwned + Default + Clone + Debug
75{
76}
77
78impl<X> LaceDatum for X where
79    X: Sync + Serialize + DeserializeOwned + Default + Clone + Debug
80{
81}
82
83/// A Lace-ready datum.
84pub trait LaceStat:
85    Sync + Serialize + DeserializeOwned + Debug + Clone + PartialEq
86{
87}
88impl<X> LaceStat for X where
89    X: Sync + Serialize + DeserializeOwned + Debug + Clone + PartialEq
90{
91}
92
93/// A Lace-ready likelihood function, f(x).
94pub trait LaceLikelihood<X: LaceDatum>:
95    Sampleable<X>
96    + HasDensity<X>
97    + HasFType
98    + TranslateDatum<X>
99    + TranslateContainer<X>
100    + Mode<X>
101    + AccumScore<X>
102    + HasSuffStat<X>
103    + Serialize
104    + DeserializeOwned
105    + Sync
106    + Into<Component>
107    + Clone
108    + Debug
109    + PartialEq
110{
111    /// The maximum value the likelihood can take on for this component
112    fn ln_f_max(&self) -> Option<f64> {
113        self.mode().map(|x| self.ln_f(&x))
114    }
115}
116
117impl<X, Fx> LaceLikelihood<X> for Fx
118where
119    X: LaceDatum,
120    Fx: Sampleable<X>
121        + HasDensity<X>
122        + HasFType
123        + TranslateDatum<X>
124        + TranslateContainer<X>
125        + Mode<X>
126        + AccumScore<X>
127        + HasSuffStat<X>
128        + Serialize
129        + DeserializeOwned
130        + Sync
131        + Into<Component>
132        + Clone
133        + Debug
134        + PartialEq,
135    Fx::Stat: Sync + Serialize + DeserializeOwned + Clone + Debug,
136{
137}
138
139/// A Lace-ready prior π(f)
140pub trait LacePrior<X: LaceDatum, Fx: LaceLikelihood<X>, H>:
141    ConjugatePrior<X, Fx>
142    + HasDensity<Fx>
143    + UpdatePrior<X, Fx, H>
144    + Serialize
145    + DeserializeOwned
146    + Sync
147    + Clone
148    + Debug
149{
150    // Create an empty sufficient statistic for a component
151    fn empty_suffstat(&self) -> Fx::Stat;
152    // Create a dummy component whose parameters **will be** immediately be
153    // overwritten
154    //
155    // # Note
156    // The component must still have the correct dimension for the column. For
157    // example, a categorical column must have the correct `k`.
158    fn invalid_temp_component(&self) -> Fx;
159    // Compute the score of the column for the column reassignment
160    fn score_column<I: Iterator<Item = Fx::Stat>>(&self, stats: I) -> f64;
161}
162
163impl LacePrior<u32, Categorical, CsdHyper> for SymmetricDirichlet {
164    fn empty_suffstat(&self) -> CategoricalSuffStat {
165        CategoricalSuffStat::new(self.k())
166    }
167
168    fn invalid_temp_component(&self) -> Categorical {
169        // XXX: This is not a valid distribution. The weights do not sum to 1. I
170        // want to leave this invalid, because I want it to show up if we use
171        // this someplace we're not supposed to. Anywhere this is supposed to be
172        // use used, the bad weights would be immediately overwritten.
173        Categorical::new_unchecked(vec![0.0; self.k()])
174    }
175
176    fn score_column<I: Iterator<Item = CategoricalSuffStat>>(
177        &self,
178        stats: I,
179    ) -> f64 {
180        let sum_alpha = self.alpha() * self.k() as f64;
181        let a = ::special::Gamma::ln_gamma(sum_alpha).0;
182        let d = ::special::Gamma::ln_gamma(self.alpha()).0 * self.k() as f64;
183        stats
184            .map(|stat| {
185                let b =
186                    ::special::Gamma::ln_gamma(sum_alpha + stat.n() as f64).0;
187                let c = stat.counts().iter().fold(0.0, |acc, &ct| {
188                    acc + ::special::Gamma::ln_gamma(self.alpha() + ct).0
189                });
190                a - b + c - d
191            })
192            .sum::<f64>()
193    }
194}
195
196#[inline]
197fn poisson_zn(shape: f64, rate: f64, stat: &PoissonSuffStat) -> f64 {
198    let shape_n = shape + stat.sum();
199    let rate_n = rate + stat.n() as f64;
200    let ln_gamma_shape = ::special::Gamma::ln_gamma(shape_n).0;
201    let ln_rate = rate_n.ln();
202    shape_n.mul_add(-ln_rate, ln_gamma_shape)
203}
204
205impl LacePrior<u32, Poisson, PgHyper> for Gamma {
206    fn empty_suffstat(&self) -> PoissonSuffStat {
207        PoissonSuffStat::new()
208    }
209
210    fn invalid_temp_component(&self) -> Poisson {
211        Poisson::new_unchecked(1.0)
212    }
213
214    fn score_column<I: Iterator<Item = PoissonSuffStat>>(
215        &self,
216        stats: I,
217    ) -> f64 {
218        let shape = self.shape();
219        let rate = self.rate();
220        let z0 = {
221            let ln_gamma_shape = ::special::Gamma::ln_gamma(shape).0;
222            let ln_rate = rate.ln();
223            shape.mul_add(-ln_rate, ln_gamma_shape)
224        };
225        stats
226            .map(|stat| {
227                let zn = poisson_zn(shape, rate, &stat);
228                zn - z0 - stat.sum_ln_fact()
229            })
230            .sum::<f64>()
231    }
232}
233
234impl LacePrior<bool, Bernoulli, ()> for Beta {
235    fn empty_suffstat(&self) -> BernoulliSuffStat {
236        BernoulliSuffStat::new()
237    }
238
239    fn invalid_temp_component(&self) -> Bernoulli {
240        Bernoulli::uniform()
241    }
242
243    fn score_column<I: Iterator<Item = BernoulliSuffStat>>(
244        &self,
245        stats: I,
246    ) -> f64 {
247        use rv::data::DataOrSuffStat;
248        let cache = <Beta as ConjugatePrior<bool, Bernoulli>>::ln_m_cache(self);
249        stats
250            .map(|stat| {
251                let x = DataOrSuffStat::SuffStat::<bool, Bernoulli>(&stat);
252                self.ln_m_with_cache(&cache, &x)
253            })
254            .sum::<f64>()
255    }
256}
257
258impl LacePrior<f64, Gaussian, NixHyper> for NormalInvChiSquared {
259    fn empty_suffstat(&self) -> GaussianSuffStat {
260        GaussianSuffStat::new()
261    }
262
263    fn invalid_temp_component(&self) -> Gaussian {
264        Gaussian::standard()
265    }
266
267    fn score_column<I: Iterator<Item = GaussianSuffStat>>(
268        &self,
269        stats: I,
270    ) -> f64 {
271        use rv::data::DataOrSuffStat;
272        let cache =
273            <NormalInvChiSquared as ConjugatePrior<f64, Gaussian>>::ln_m_cache(
274                self,
275            );
276        stats
277            .map(|stat| {
278                let x = DataOrSuffStat::SuffStat(&stat);
279                <NormalInvChiSquared as ConjugatePrior<f64, Gaussian>>::ln_m_with_cache(self, &cache, &x)
280            })
281            .sum::<f64>()
282    }
283}