// Source: rv/data/stat/invgaussian.rs

1#[cfg(feature = "serde1")]
2use serde::{Deserialize, Serialize};
3
4use crate::data::DataOrSuffStat;
5use crate::dist::InvGaussian;
6use crate::traits::SuffStat;
7
/// Inverse Gaussian sufficient statistic.
///
/// Holds the number of observations together with the sum of the observed
/// values `x`, the sum of their reciprocals `1/x`, and the sum of their
/// natural logarithms `ln(x)`. All sums are accumulated in `f64`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct InvGaussianSuffStat {
    /// Number of observations
    n: usize,
    /// Sum of `x`
    sum_x: f64,
    /// Sum of `1/x`
    sum_inv_x: f64,
    /// Sum of `ln(x)`
    sum_ln_x: f64,
}
25
26impl InvGaussianSuffStat {
27    #[inline]
28    #[must_use]
29    pub fn new() -> Self {
30        InvGaussianSuffStat {
31            n: 0,
32            sum_x: 0.0,
33            sum_inv_x: 0.0,
34            sum_ln_x: 0.0,
35        }
36    }
37
38    /// Create a sufficient statistic from components without checking whether
39    /// they are valid.
40    ///
41    /// # Example
42    /// ```rust
43    /// use rv::data::InvGaussianSuffStat;
44    /// use rv::prelude::SuffStat;
45    ///
46    /// let data: Vec<f64> = vec![0.1, 0.2, 0.3];
47    ///
48    /// let mut stat_b = InvGaussianSuffStat::new();
49    /// stat_b.observe_many(&data);
50    ///
51    /// let n = data.len();
52    /// let sum_x = data.iter().sum();
53    /// let sum_inv_x = data.iter().map(|x: &f64| x.recip()).sum();
54    /// let sum_ln_x = data.iter().map(|x: &f64| x.ln()).sum();
55    ///
56    /// let stat_a = InvGaussianSuffStat::from_parts_unchecked(n, sum_x, sum_inv_x, sum_ln_x);
57    ///
58    /// assert_eq!(stat_a.n(), stat_b.n());
59    /// assert::close(stat_a.sum_x(), stat_b.sum_x(), 1e-10);
60    /// assert::close(stat_a.sum_inv_x(), stat_b.sum_inv_x(), 1e-10);
61    /// assert::close(stat_a.sum_ln_x(), stat_b.sum_ln_x(), 1e-10);
62    /// ```
63    #[inline]
64    #[must_use]
65    pub fn from_parts_unchecked(
66        n: usize,
67        sum_x: f64,
68        sum_inv_x: f64,
69        sum_ln_x: f64,
70    ) -> Self {
71        InvGaussianSuffStat {
72            n,
73            sum_x,
74            sum_inv_x,
75            sum_ln_x,
76        }
77    }
78
79    /// Get the number of observations
80    #[inline]
81    #[must_use]
82    pub fn n(&self) -> usize {
83        self.n
84    }
85
86    /// Get the sample mean
87    #[inline]
88    #[must_use]
89    pub fn mean(&self) -> f64 {
90        self.sum_x / self.n as f64
91    }
92
93    /// Sum of `x`
94    #[inline]
95    #[must_use]
96    pub fn sum_x(&self) -> f64 {
97        self.sum_x
98    }
99
100    /// Sum of `1/x`
101    #[inline]
102    #[must_use]
103    pub fn sum_inv_x(&self) -> f64 {
104        self.sum_inv_x
105    }
106
107    #[inline]
108    #[must_use]
109    pub fn sum_ln_x(&self) -> f64 {
110        self.sum_ln_x
111    }
112}
113
114impl Default for InvGaussianSuffStat {
115    fn default() -> Self {
116        InvGaussianSuffStat::new()
117    }
118}
119
120macro_rules! impl_invgaussian_suffstat {
121    ($kind:ty) => {
122        impl<'a> From<&'a InvGaussianSuffStat>
123            for DataOrSuffStat<'a, $kind, InvGaussian>
124        {
125            fn from(stat: &'a InvGaussianSuffStat) -> Self {
126                DataOrSuffStat::SuffStat(stat)
127            }
128        }
129
130        impl<'a> From<&'a Vec<$kind>>
131            for DataOrSuffStat<'a, $kind, InvGaussian>
132        {
133            fn from(xs: &'a Vec<$kind>) -> Self {
134                DataOrSuffStat::Data(xs.as_slice())
135            }
136        }
137
138        impl<'a> From<&'a [$kind]> for DataOrSuffStat<'a, $kind, InvGaussian> {
139            fn from(xs: &'a [$kind]) -> Self {
140                DataOrSuffStat::Data(xs)
141            }
142        }
143
144        impl From<&Vec<$kind>> for InvGaussianSuffStat {
145            fn from(xs: &Vec<$kind>) -> Self {
146                let mut stat = InvGaussianSuffStat::new();
147                stat.observe_many(xs);
148                stat
149            }
150        }
151
152        impl From<&[$kind]> for InvGaussianSuffStat {
153            fn from(xs: &[$kind]) -> Self {
154                let mut stat = InvGaussianSuffStat::new();
155                stat.observe_many(xs);
156                stat
157            }
158        }
159
160        impl SuffStat<$kind> for InvGaussianSuffStat {
161            fn n(&self) -> usize {
162                self.n
163            }
164
165            fn observe(&mut self, x: &$kind) {
166                let xf = f64::from(*x);
167
168                self.n += 1;
169
170                self.sum_x += xf;
171                self.sum_inv_x += xf.recip();
172                self.sum_ln_x += xf.ln();
173            }
174
175            fn forget(&mut self, x: &$kind) {
176                if self.n > 1 {
177                    let xf = f64::from(*x);
178
179                    self.sum_x -= xf;
180                    self.sum_inv_x -= xf.recip();
181                    self.sum_ln_x -= xf.ln();
182                    self.n -= 1;
183                } else {
184                    self.n = 0;
185                    self.sum_x = 0.0;
186                    self.sum_inv_x = 0.0;
187                    self.sum_ln_x = 0.0;
188                }
189            }
190            fn merge(&mut self, other: Self) {
191                self.n += other.n;
192                self.sum_x += other.sum_x;
193                self.sum_inv_x += other.sum_inv_x;
194                self.sum_ln_x += other.sum_ln_x;
195            }
196        }
197    };
198}
199
200impl_invgaussian_suffstat!(f32);
201impl_invgaussian_suffstat!(f64);
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn observe_forget() {
        let mut stat = InvGaussianSuffStat::new();

        stat.observe(&0.1);
        stat.observe(&0.2);

        assert_eq!(stat.n(), 2);
        assert::close(stat.sum_x, 0.1_f64 + 0.2_f64, 1e-10);
        assert::close(
            stat.sum_inv_x,
            (0.1_f64).recip() + (0.2_f64).recip(),
            1e-10,
        );
        assert::close(stat.sum_ln_x, (0.1_f64).ln() + (0.2_f64).ln(), 1e-10);

        stat.forget(&0.1);

        assert_eq!(stat.n(), 1);
        assert::close(stat.sum_x, 0.2_f64, 1e-10);
        assert::close(stat.sum_inv_x, (0.2_f64).recip(), 1e-10);
        assert::close(stat.sum_ln_x, (0.2_f64).ln(), 1e-10);

        stat.forget(&0.2);

        // Forgetting the final observation must reset every accumulator
        // exactly, not just leave near-zero rounding residue.
        assert_eq!(stat.n(), 0);
        assert_eq!(stat.sum_x, 0.0);
        assert_eq!(stat.sum_inv_x, 0.0);
        assert_eq!(stat.sum_ln_x, 0.0);
        assert_eq!(stat, InvGaussianSuffStat::new());
    }

    #[test]
    fn forget_on_empty_stays_empty() {
        // `n` must not underflow when forgetting from an empty statistic.
        let mut stat = InvGaussianSuffStat::new();
        stat.forget(&0.5_f64);
        assert_eq!(stat, InvGaussianSuffStat::new());
    }

    #[test]
    fn f32_and_f64_observations_agree() {
        // These values are exactly representable in both widths, so widening
        // `f32` to `f64` must give bit-identical accumulators.
        let mut a = InvGaussianSuffStat::new();
        let mut b = InvGaussianSuffStat::new();
        a.observe_many(&[0.5_f32, 2.0, 4.0]);
        b.observe_many(&[0.5_f64, 2.0, 4.0]);
        assert_eq!(a, b);
    }

    #[test]
    fn mean_of_observations() {
        let mut stat = InvGaussianSuffStat::new();
        stat.observe_many(&[1.0_f64, 2.0, 6.0]);
        assert::close(stat.mean(), 3.0, 1e-12);

        // The mean of an empty statistic is 0/0.
        assert!(InvGaussianSuffStat::new().mean().is_nan());
    }

    #[test]
    fn merge() {
        let mut a = InvGaussianSuffStat::new();
        let mut b = InvGaussianSuffStat::new();
        let mut c = InvGaussianSuffStat::new();

        a.observe_many(&[0.1_f64, 0.2, 0.3]);
        b.observe_many(&[0.9_f64, 0.8, 0.7]);

        c.observe_many(&[0.1_f64, 0.2, 0.3, 0.9, 0.8, 0.7]);

        <InvGaussianSuffStat as SuffStat<f64>>::merge(&mut a, b);

        assert_eq!(a.n(), c.n());
        assert::close(a.sum_x(), c.sum_x(), 1e-10);
        assert::close(a.sum_inv_x(), c.sum_inv_x(), 1e-10);
        assert::close(a.sum_ln_x(), c.sum_ln_x(), 1e-10);
    }
}