rv/data/stat/
invgaussian.rs1#[cfg(feature = "serde1")]
2use serde::{Deserialize, Serialize};
3
4use crate::data::DataOrSuffStat;
5use crate::dist::InvGaussian;
6use crate::traits::SuffStat;
7
/// Sufficient statistic for the inverse Gaussian distribution.
///
/// Holds the number of observations along with running sums of `x`,
/// `1 / x`, and `ln(x)` over every observed value.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct InvGaussianSuffStat {
    /// How many observations have been folded in.
    n: usize,
    /// Running sum Σ x.
    sum_x: f64,
    /// Running sum Σ 1/x.
    sum_inv_x: f64,
    /// Running sum Σ ln(x).
    sum_ln_x: f64,
}
25
26impl InvGaussianSuffStat {
27 #[inline]
28 #[must_use]
29 pub fn new() -> Self {
30 InvGaussianSuffStat {
31 n: 0,
32 sum_x: 0.0,
33 sum_inv_x: 0.0,
34 sum_ln_x: 0.0,
35 }
36 }
37
38 #[inline]
64 #[must_use]
65 pub fn from_parts_unchecked(
66 n: usize,
67 sum_x: f64,
68 sum_inv_x: f64,
69 sum_ln_x: f64,
70 ) -> Self {
71 InvGaussianSuffStat {
72 n,
73 sum_x,
74 sum_inv_x,
75 sum_ln_x,
76 }
77 }
78
79 #[inline]
81 #[must_use]
82 pub fn n(&self) -> usize {
83 self.n
84 }
85
86 #[inline]
88 #[must_use]
89 pub fn mean(&self) -> f64 {
90 self.sum_x / self.n as f64
91 }
92
93 #[inline]
95 #[must_use]
96 pub fn sum_x(&self) -> f64 {
97 self.sum_x
98 }
99
100 #[inline]
102 #[must_use]
103 pub fn sum_inv_x(&self) -> f64 {
104 self.sum_inv_x
105 }
106
107 #[inline]
108 #[must_use]
109 pub fn sum_ln_x(&self) -> f64 {
110 self.sum_ln_x
111 }
112}
113
114impl Default for InvGaussianSuffStat {
115 fn default() -> Self {
116 InvGaussianSuffStat::new()
117 }
118}
119
120macro_rules! impl_invgaussian_suffstat {
121 ($kind:ty) => {
122 impl<'a> From<&'a InvGaussianSuffStat>
123 for DataOrSuffStat<'a, $kind, InvGaussian>
124 {
125 fn from(stat: &'a InvGaussianSuffStat) -> Self {
126 DataOrSuffStat::SuffStat(stat)
127 }
128 }
129
130 impl<'a> From<&'a Vec<$kind>>
131 for DataOrSuffStat<'a, $kind, InvGaussian>
132 {
133 fn from(xs: &'a Vec<$kind>) -> Self {
134 DataOrSuffStat::Data(xs.as_slice())
135 }
136 }
137
138 impl<'a> From<&'a [$kind]> for DataOrSuffStat<'a, $kind, InvGaussian> {
139 fn from(xs: &'a [$kind]) -> Self {
140 DataOrSuffStat::Data(xs)
141 }
142 }
143
144 impl From<&Vec<$kind>> for InvGaussianSuffStat {
145 fn from(xs: &Vec<$kind>) -> Self {
146 let mut stat = InvGaussianSuffStat::new();
147 stat.observe_many(xs);
148 stat
149 }
150 }
151
152 impl From<&[$kind]> for InvGaussianSuffStat {
153 fn from(xs: &[$kind]) -> Self {
154 let mut stat = InvGaussianSuffStat::new();
155 stat.observe_many(xs);
156 stat
157 }
158 }
159
160 impl SuffStat<$kind> for InvGaussianSuffStat {
161 fn n(&self) -> usize {
162 self.n
163 }
164
165 fn observe(&mut self, x: &$kind) {
166 let xf = f64::from(*x);
167
168 self.n += 1;
169
170 self.sum_x += xf;
171 self.sum_inv_x += xf.recip();
172 self.sum_ln_x += xf.ln();
173 }
174
175 fn forget(&mut self, x: &$kind) {
176 if self.n > 1 {
177 let xf = f64::from(*x);
178
179 self.sum_x -= xf;
180 self.sum_inv_x -= xf.recip();
181 self.sum_ln_x -= xf.ln();
182 self.n -= 1;
183 } else {
184 self.n = 0;
185 self.sum_x = 0.0;
186 self.sum_inv_x = 0.0;
187 self.sum_ln_x = 0.0;
188 }
189 }
190 fn merge(&mut self, other: Self) {
191 self.n += other.n;
192 self.sum_x += other.sum_x;
193 self.sum_inv_x += other.sum_inv_x;
194 self.sum_ln_x += other.sum_ln_x;
195 }
196 }
197 };
198}
199
200impl_invgaussian_suffstat!(f32);
201impl_invgaussian_suffstat!(f64);
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn observe_forget() {
        let mut stat = InvGaussianSuffStat::new();

        stat.observe(&0.1);
        stat.observe(&0.2);

        assert_eq!(stat.n(), 2);
        assert::close(stat.sum_x, 0.1_f64 + 0.2_f64, 1e-10);
        assert::close(
            stat.sum_inv_x,
            (0.1_f64).recip() + (0.2_f64).recip(),
            1e-10,
        );
        assert::close(stat.sum_ln_x, (0.1_f64).ln() + (0.2_f64).ln(), 1e-10);

        stat.forget(&0.1);

        assert_eq!(stat.n(), 1);
        assert::close(stat.sum_x, 0.2_f64, 1e-10);
        assert::close(stat.sum_inv_x, (0.2_f64).recip(), 1e-10);
        assert::close(stat.sum_ln_x, (0.2_f64).ln(), 1e-10);

        stat.forget(&0.2);

        // Removing the final observation must reset *every* accumulator,
        // not only the count and the log-sum.
        assert_eq!(stat.n(), 0);
        assert_eq!(stat.sum_x(), 0.0);
        assert_eq!(stat.sum_inv_x(), 0.0);
        assert_eq!(stat.sum_ln_x(), 0.0);
        assert_eq!(stat, InvGaussianSuffStat::new());
    }

    #[test]
    fn forget_on_empty_keeps_stat_empty() {
        // Must not underflow `n` or leave non-zero sums behind.
        let mut stat = InvGaussianSuffStat::new();
        stat.forget(&1.0_f64);
        assert_eq!(stat, InvGaussianSuffStat::default());
    }

    #[test]
    fn mean_is_sum_over_count() {
        let stat = InvGaussianSuffStat::from(&vec![1.0_f64, 2.0, 6.0]);
        assert_eq!(stat.n(), 3);
        assert::close(stat.mean(), 3.0, 1e-12);

        // With no observations the mean is 0.0 / 0.0.
        assert!(InvGaussianSuffStat::new().mean().is_nan());
    }

    #[test]
    fn f32_data_widens_to_f64() {
        let stat = InvGaussianSuffStat::from(&[0.5_f32, 2.0][..]);
        assert_eq!(stat.n(), 2);
        assert::close(stat.sum_x(), 2.5, 1e-12);
        assert::close(stat.sum_inv_x(), 2.5, 1e-12);
        assert::close(stat.sum_ln_x(), 0.0, 1e-12);
    }

    #[test]
    fn merge() {
        let mut a = InvGaussianSuffStat::new();
        let mut b = InvGaussianSuffStat::new();
        let mut c = InvGaussianSuffStat::new();

        a.observe_many(&[0.1_f64, 0.2, 0.3]);
        b.observe_many(&[0.9_f64, 0.8, 0.7]);

        c.observe_many(&[0.1_f64, 0.2, 0.3, 0.9, 0.8, 0.7]);

        <InvGaussianSuffStat as SuffStat<f64>>::merge(&mut a, b);

        assert_eq!(a.n(), c.n());
        assert::close(a.sum_x(), c.sum_x(), 1e-10);
        assert::close(a.sum_inv_x(), c.sum_inv_x(), 1e-10);
        assert::close(a.sum_ln_x(), c.sum_ln_x(), 1e-10);
    }
}