1use std::fmt::Debug;
2
3use rv::data::BernoulliSuffStat;
4use rv::data::CategoricalDatum;
5use rv::data::CategoricalSuffStat;
6use rv::data::GaussianSuffStat;
7use rv::data::PoissonSuffStat;
8use rv::dist::Bernoulli;
9use rv::dist::Beta;
10use rv::dist::Categorical;
11use rv::dist::Gamma;
12use rv::dist::Gaussian;
13use rv::dist::NormalInvChiSquared;
14use rv::dist::Poisson;
15use rv::dist::SymmetricDirichlet;
16use rv::traits::ConjugatePrior;
17use rv::traits::HasDensity;
18use rv::traits::HasSuffStat;
19use rv::traits::Mode;
20use rv::traits::Rv;
21use rv::traits::Sampleable;
22use serde::de::DeserializeOwned;
23use serde::Serialize;
24
25use crate::cc::feature::Component;
26use crate::cc::feature::FType;
27use crate::data::SparseContainer;
28use crate::data::TranslateContainer;
29use crate::data::TranslateDatum;
30use crate::stats::prior::csd::CsdHyper;
31use crate::stats::prior::nix::NixHyper;
32use crate::stats::prior::pg::PgHyper;
33use crate::stats::UpdatePrior;
34
35pub trait AccumScore<X: Clone + Default>: Rv<X> + Sync {
40 fn accum_score(&self, scores: &mut [f64], container: &SparseContainer<X>) {
43 use crate::data::AccumScore;
44 container.accum_score(scores, &|x| self.ln_f(x))
45 }
46}
47
48impl<X: CategoricalDatum + Default> AccumScore<X> for Categorical {}
49impl AccumScore<u32> for Poisson {}
50impl AccumScore<f64> for Gaussian {}
51impl AccumScore<bool> for Bernoulli {}
52
53pub trait HasFType {
54 fn ftype() -> FType;
55}
56
57macro_rules! impl_ftype {
58 ($Fx: ty, $ftype: ident) => {
59 impl $crate::cc::traits::HasFType for $Fx {
60 fn ftype() -> $crate::cc::feature::FType {
61 $crate::cc::feature::FType::$ftype
62 }
63 }
64 };
65}
66
67impl_ftype!(Poisson, Count);
68impl_ftype!(Gaussian, Continuous);
69impl_ftype!(Bernoulli, Binary);
70impl_ftype!(Categorical, Categorical);
71
72pub trait LaceDatum:
74 Sync + Serialize + DeserializeOwned + Default + Clone + Debug
75{
76}
77
78impl<X> LaceDatum for X where
79 X: Sync + Serialize + DeserializeOwned + Default + Clone + Debug
80{
81}
82
83pub trait LaceStat:
85 Sync + Serialize + DeserializeOwned + Debug + Clone + PartialEq
86{
87}
88impl<X> LaceStat for X where
89 X: Sync + Serialize + DeserializeOwned + Debug + Clone + PartialEq
90{
91}
92
93pub trait LaceLikelihood<X: LaceDatum>:
95 Sampleable<X>
96 + HasDensity<X>
97 + HasFType
98 + TranslateDatum<X>
99 + TranslateContainer<X>
100 + Mode<X>
101 + AccumScore<X>
102 + HasSuffStat<X>
103 + Serialize
104 + DeserializeOwned
105 + Sync
106 + Into<Component>
107 + Clone
108 + Debug
109 + PartialEq
110{
111 fn ln_f_max(&self) -> Option<f64> {
113 self.mode().map(|x| self.ln_f(&x))
114 }
115}
116
117impl<X, Fx> LaceLikelihood<X> for Fx
118where
119 X: LaceDatum,
120 Fx: Sampleable<X>
121 + HasDensity<X>
122 + HasFType
123 + TranslateDatum<X>
124 + TranslateContainer<X>
125 + Mode<X>
126 + AccumScore<X>
127 + HasSuffStat<X>
128 + Serialize
129 + DeserializeOwned
130 + Sync
131 + Into<Component>
132 + Clone
133 + Debug
134 + PartialEq,
135 Fx::Stat: Sync + Serialize + DeserializeOwned + Clone + Debug,
136{
137}
138
139pub trait LacePrior<X: LaceDatum, Fx: LaceLikelihood<X>, H>:
141 ConjugatePrior<X, Fx>
142 + HasDensity<Fx>
143 + UpdatePrior<X, Fx, H>
144 + Serialize
145 + DeserializeOwned
146 + Sync
147 + Clone
148 + Debug
149{
150 fn empty_suffstat(&self) -> Fx::Stat;
152 fn invalid_temp_component(&self) -> Fx;
159 fn score_column<I: Iterator<Item = Fx::Stat>>(&self, stats: I) -> f64;
161}
162
163impl LacePrior<u32, Categorical, CsdHyper> for SymmetricDirichlet {
164 fn empty_suffstat(&self) -> CategoricalSuffStat {
165 CategoricalSuffStat::new(self.k())
166 }
167
168 fn invalid_temp_component(&self) -> Categorical {
169 Categorical::new_unchecked(vec![0.0; self.k()])
174 }
175
176 fn score_column<I: Iterator<Item = CategoricalSuffStat>>(
177 &self,
178 stats: I,
179 ) -> f64 {
180 let sum_alpha = self.alpha() * self.k() as f64;
181 let a = ::special::Gamma::ln_gamma(sum_alpha).0;
182 let d = ::special::Gamma::ln_gamma(self.alpha()).0 * self.k() as f64;
183 stats
184 .map(|stat| {
185 let b =
186 ::special::Gamma::ln_gamma(sum_alpha + stat.n() as f64).0;
187 let c = stat.counts().iter().fold(0.0, |acc, &ct| {
188 acc + ::special::Gamma::ln_gamma(self.alpha() + ct).0
189 });
190 a - b + c - d
191 })
192 .sum::<f64>()
193 }
194}
195
196#[inline]
197fn poisson_zn(shape: f64, rate: f64, stat: &PoissonSuffStat) -> f64 {
198 let shape_n = shape + stat.sum();
199 let rate_n = rate + stat.n() as f64;
200 let ln_gamma_shape = ::special::Gamma::ln_gamma(shape_n).0;
201 let ln_rate = rate_n.ln();
202 shape_n.mul_add(-ln_rate, ln_gamma_shape)
203}
204
205impl LacePrior<u32, Poisson, PgHyper> for Gamma {
206 fn empty_suffstat(&self) -> PoissonSuffStat {
207 PoissonSuffStat::new()
208 }
209
210 fn invalid_temp_component(&self) -> Poisson {
211 Poisson::new_unchecked(1.0)
212 }
213
214 fn score_column<I: Iterator<Item = PoissonSuffStat>>(
215 &self,
216 stats: I,
217 ) -> f64 {
218 let shape = self.shape();
219 let rate = self.rate();
220 let z0 = {
221 let ln_gamma_shape = ::special::Gamma::ln_gamma(shape).0;
222 let ln_rate = rate.ln();
223 shape.mul_add(-ln_rate, ln_gamma_shape)
224 };
225 stats
226 .map(|stat| {
227 let zn = poisson_zn(shape, rate, &stat);
228 zn - z0 - stat.sum_ln_fact()
229 })
230 .sum::<f64>()
231 }
232}
233
234impl LacePrior<bool, Bernoulli, ()> for Beta {
235 fn empty_suffstat(&self) -> BernoulliSuffStat {
236 BernoulliSuffStat::new()
237 }
238
239 fn invalid_temp_component(&self) -> Bernoulli {
240 Bernoulli::uniform()
241 }
242
243 fn score_column<I: Iterator<Item = BernoulliSuffStat>>(
244 &self,
245 stats: I,
246 ) -> f64 {
247 use rv::data::DataOrSuffStat;
248 let cache = <Beta as ConjugatePrior<bool, Bernoulli>>::ln_m_cache(self);
249 stats
250 .map(|stat| {
251 let x = DataOrSuffStat::SuffStat::<bool, Bernoulli>(&stat);
252 self.ln_m_with_cache(&cache, &x)
253 })
254 .sum::<f64>()
255 }
256}
257
258impl LacePrior<f64, Gaussian, NixHyper> for NormalInvChiSquared {
259 fn empty_suffstat(&self) -> GaussianSuffStat {
260 GaussianSuffStat::new()
261 }
262
263 fn invalid_temp_component(&self) -> Gaussian {
264 Gaussian::standard()
265 }
266
267 fn score_column<I: Iterator<Item = GaussianSuffStat>>(
268 &self,
269 stats: I,
270 ) -> f64 {
271 use rv::data::DataOrSuffStat;
272 let cache =
273 <NormalInvChiSquared as ConjugatePrior<f64, Gaussian>>::ln_m_cache(
274 self,
275 );
276 stats
277 .map(|stat| {
278 let x = DataOrSuffStat::SuffStat(&stat);
279 <NormalInvChiSquared as ConjugatePrior<f64, Gaussian>>::ln_m_with_cache(self, &cache, &x)
280 })
281 .sum::<f64>()
282 }
283}