rv/dist/normal_inv_gamma.rs

1//! A common conjugate prior for Gaussians with unknown mean and variance
2//!
3//! For a reference see section 6 of [Kevin Murphy's
4//! whitepaper](https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf).
5#[cfg(feature = "serde1")]
6use serde::{Deserialize, Serialize};
7
8mod gaussian_prior;
9
10use crate::dist::{Gaussian, InvGamma};
11use crate::impl_display;
12use crate::traits::*;
13use rand::Rng;
14use std::fmt;
15
/// Normal Inverse Gamma prior for a Gaussian with unknown mean and variance.
///
/// Given `x ~ N(μ, σ)`, where `σ` is the standard deviation, the Normal
/// Inverse Gamma prior implies that
///
/// - `σ² ~ InvGamma(a, b)`, and
/// - `μ | σ² ~ N(m, sqrt(v)σ)`, i.e. the prior variance of `μ` is `v σ²`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct NormalInvGamma {
    /// Prior mean of `μ`.
    m: f64,
    /// Scale of the prior variance of `μ` relative to `σ²`.
    v: f64,
    /// Shape of the inverse-gamma prior on `σ²`.
    a: f64,
    /// Scale of the inverse-gamma prior on `σ²`.
    b: f64,
}
29
/// Plain-data snapshot of the four hyperparameters of a `NormalInvGamma`.
///
/// Produced by `Parameterized::emit_params` and consumed (without
/// validation) by `Parameterized::from_params`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct NormalInvGammaParameters {
    /// Prior mean of `μ`
    pub m: f64,
    /// Scale of the prior variance of `μ` relative to `σ²`
    pub v: f64,
    /// Shape of the inverse-gamma prior on `σ²`
    pub a: f64,
    /// Scale of the inverse-gamma prior on `σ²`
    pub b: f64,
}
36
37impl Parameterized for NormalInvGamma {
38    type Parameters = NormalInvGammaParameters;
39
40    fn emit_params(&self) -> Self::Parameters {
41        Self::Parameters {
42            m: self.m(),
43            v: self.v(),
44            a: self.a(),
45            b: self.b(),
46        }
47    }
48
49    fn from_params(params: Self::Parameters) -> Self {
50        Self::new_unchecked(params.m, params.v, params.a, params.b)
51    }
52}
53
/// Error returned when a `NormalInvGamma` is constructed, or one of its
/// parameters is set, with an invalid value.
///
/// Each variant carries the offending value.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum NormalInvGammaError {
    /// The m parameter is infinite or NaN
    MNotFinite { m: f64 },
    /// The v parameter is less than or equal to zero
    VTooLow { v: f64 },
    /// The v parameter is infinite or NaN
    VNotFinite { v: f64 },
    /// The a parameter is less than or equal to zero
    ATooLow { a: f64 },
    /// The a parameter is infinite or NaN
    ANotFinite { a: f64 },
    /// The b parameter is less than or equal to zero
    BTooLow { b: f64 },
    /// The b parameter is infinite or NaN
    BNotFinite { b: f64 },
}
73
74impl NormalInvGamma {
75    /// Create a new Normal Inverse Gamma distribution
76    ///
77    /// # Arguments
78    /// - m: The prior mean
79    /// - v: Relative variance of μ versus data
80    /// - a: The mean of variance is b / (a - 1)
81    /// - b: Degrees of freedom of the variance
82    pub fn new(
83        m: f64,
84        v: f64,
85        a: f64,
86        b: f64,
87    ) -> Result<Self, NormalInvGammaError> {
88        if !m.is_finite() {
89            Err(NormalInvGammaError::MNotFinite { m })
90        } else if !v.is_finite() {
91            Err(NormalInvGammaError::VNotFinite { v })
92        } else if !a.is_finite() {
93            Err(NormalInvGammaError::ANotFinite { a })
94        } else if !b.is_finite() {
95            Err(NormalInvGammaError::BNotFinite { b })
96        } else if v <= 0.0 {
97            Err(NormalInvGammaError::VTooLow { v })
98        } else if a <= 0.0 {
99            Err(NormalInvGammaError::ATooLow { a })
100        } else if b <= 0.0 {
101            Err(NormalInvGammaError::BTooLow { b })
102        } else {
103            Ok(NormalInvGamma { m, v, a, b })
104        }
105    }
106
107    /// Creates a new NormalInvGamma without checking whether the parameters are
108    /// valid.
109    #[inline(always)]
110    pub fn new_unchecked(m: f64, v: f64, a: f64, b: f64) -> Self {
111        NormalInvGamma { m, v, a, b }
112    }
113
114    /// Get the m parameter
115    #[inline(always)]
116    pub fn m(&self) -> f64 {
117        self.m
118    }
119
120    /// Set the value of m
121    ///
122    /// # Example
123    ///
124    /// ```rust
125    /// use rv::dist::NormalInvGamma;
126    ///
127    /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
128    /// assert_eq!(nig.m(), 0.0);
129    ///
130    /// nig.set_m(-1.1).unwrap();
131    /// assert_eq!(nig.m(), -1.1);
132    /// ```
133    ///
134    /// Will error for invalid values
135    ///
136    /// ```rust
137    /// # use rv::dist::NormalInvGamma;
138    /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
139    /// assert!(nig.set_m(-1.1).is_ok());
140    /// assert!(nig.set_m(f64::INFINITY).is_err());
141    /// assert!(nig.set_m(f64::NEG_INFINITY).is_err());
142    /// assert!(nig.set_m(f64::NAN).is_err());
143    /// ```
144    #[inline]
145    pub fn set_m(&mut self, m: f64) -> Result<(), NormalInvGammaError> {
146        if m.is_finite() {
147            self.set_m_unchecked(m);
148            Ok(())
149        } else {
150            Err(NormalInvGammaError::MNotFinite { m })
151        }
152    }
153
154    /// Set the value of m without input validation
155    #[inline(always)]
156    pub fn set_m_unchecked(&mut self, m: f64) {
157        self.m = m;
158    }
159
160    /// Get the v parameter
161    #[inline]
162    pub fn v(&self) -> f64 {
163        self.v
164    }
165
166    /// Set the value of v
167    ///
168    /// # Example
169    ///
170    /// ```rust
171    /// use rv::dist::NormalInvGamma;
172    ///
173    /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
174    /// assert_eq!(nig.v(), 1.2);
175    ///
176    /// nig.set_v(4.3).unwrap();
177    /// assert_eq!(nig.v(), 4.3);
178    /// ```
179    ///
180    /// Will error for invalid values
181    ///
182    /// ```rust
183    /// # use rv::dist::NormalInvGamma;
184    /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
185    /// assert!(nig.set_v(2.1).is_ok());
186    ///
187    /// // must be greater than zero
188    /// assert!(nig.set_v(0.0).is_err());
189    /// assert!(nig.set_v(-1.0).is_err());
190    ///
191    ///
192    /// assert!(nig.set_v(f64::INFINITY).is_err());
193    /// assert!(nig.set_v(f64::NEG_INFINITY).is_err());
194    /// assert!(nig.set_v(f64::NAN).is_err());
195    /// ```
196    #[inline]
197    pub fn set_v(&mut self, v: f64) -> Result<(), NormalInvGammaError> {
198        if !v.is_finite() {
199            Err(NormalInvGammaError::VNotFinite { v })
200        } else if v <= 0.0 {
201            Err(NormalInvGammaError::VTooLow { v })
202        } else {
203            self.set_v_unchecked(v);
204            Ok(())
205        }
206    }
207
208    /// Set the value of v without input validation
209    #[inline]
210    pub fn set_v_unchecked(&mut self, v: f64) {
211        self.v = v;
212    }
213
214    /// Get the a parameter
215    #[inline]
216    pub fn a(&self) -> f64 {
217        self.a
218    }
219
220    /// Set the value of a
221    ///
222    /// # Example
223    ///
224    /// ```rust
225    /// use rv::dist::NormalInvGamma;
226    ///
227    /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
228    /// assert_eq!(nig.a(), 2.3);
229    ///
230    /// nig.set_a(4.3).unwrap();
231    /// assert_eq!(nig.a(), 4.3);
232    /// ```
233    ///
234    /// Will error for invalid values
235    ///
236    /// ```rust
237    /// # use rv::dist::NormalInvGamma;
238    /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
239    /// assert!(nig.set_a(2.1).is_ok());
240    ///
241    /// // must be greater than zero
242    /// assert!(nig.set_a(0.0).is_err());
243    /// assert!(nig.set_a(-1.0).is_err());
244    ///
245    ///
246    /// assert!(nig.set_a(f64::INFINITY).is_err());
247    /// assert!(nig.set_a(f64::NEG_INFINITY).is_err());
248    /// assert!(nig.set_a(f64::NAN).is_err());
249    /// ```
250    #[inline]
251    pub fn set_a(&mut self, a: f64) -> Result<(), NormalInvGammaError> {
252        if !a.is_finite() {
253            Err(NormalInvGammaError::ANotFinite { a })
254        } else if a <= 0.0 {
255            Err(NormalInvGammaError::ATooLow { a })
256        } else {
257            self.set_a_unchecked(a);
258            Ok(())
259        }
260    }
261
262    /// Set the value of a without input validation
263    #[inline]
264    pub fn set_a_unchecked(&mut self, a: f64) {
265        self.a = a;
266    }
267
268    /// Get the b parameter
269    #[inline]
270    pub fn b(&self) -> f64 {
271        self.b
272    }
273
274    /// Set the value of b
275    ///
276    /// # Example
277    ///
278    /// ```rust
279    /// use rv::dist::NormalInvGamma;
280    ///
281    /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
282    /// assert_eq!(nig.b(), 3.4);
283    ///
284    /// nig.set_b(4.3).unwrap();
285    /// assert_eq!(nig.b(), 4.3);
286    /// ```
287    ///
288    /// Will error for invalid values
289    ///
290    /// ```rust
291    /// # use rv::dist::NormalInvGamma;
292    /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
293    /// assert!(nig.set_b(2.1).is_ok());
294    ///
295    /// // must be greater than zero
296    /// assert!(nig.set_b(0.0).is_err());
297    /// assert!(nig.set_b(-1.0).is_err());
298    ///
299    ///
300    /// assert!(nig.set_b(f64::INFINITY).is_err());
301    /// assert!(nig.set_b(f64::NEG_INFINITY).is_err());
302    /// assert!(nig.set_b(f64::NAN).is_err());
303    /// ```
304    #[inline]
305    pub fn set_b(&mut self, b: f64) -> Result<(), NormalInvGammaError> {
306        if !b.is_finite() {
307            Err(NormalInvGammaError::BNotFinite { b })
308        } else if b <= 0.0 {
309            Err(NormalInvGammaError::BTooLow { b })
310        } else {
311            self.set_b_unchecked(b);
312            Ok(())
313        }
314    }
315
316    /// Set the value of b without input validation
317    #[inline(always)]
318    pub fn set_b_unchecked(&mut self, b: f64) {
319        self.b = b;
320    }
321}
322
323impl From<&NormalInvGamma> for String {
324    fn from(nig: &NormalInvGamma) -> String {
325        format!(
326            "Normal-Inverse-Gamma(m: {}, v: {}, a: {}, b: {})",
327            nig.m, nig.v, nig.a, nig.b
328        )
329    }
330}
331
// Presumably implements `fmt::Display` for `NormalInvGamma` via the crate's
// `impl_display!` macro.
// NOTE(review): the displayed text is likely the one built by the
// `From<&NormalInvGamma> for String` impl; confirm against the macro.
impl_display!(NormalInvGamma);
333
334impl HasDensity<Gaussian> for NormalInvGamma {
335    fn ln_f(&self, x: &Gaussian) -> f64 {
336        // TODO: could cache the gamma and Gaussian distributions
337        let mu = x.mu();
338        let sigma = x.sigma();
339        let lnf_sigma =
340            InvGamma::new_unchecked(self.a, self.b).ln_f(&(sigma * sigma));
341        let prior_sigma = self.v.sqrt() * sigma;
342        let lnf_mu = Gaussian::new_unchecked(self.m, prior_sigma).ln_f(&mu);
343        lnf_sigma + lnf_mu
344    }
345}
346
347impl Sampleable<Gaussian> for NormalInvGamma {
348    fn draw<R: Rng>(&self, mut rng: &mut R) -> Gaussian {
349        // NOTE: The parameter errors in this fn shouldn't happen if the prior
350        // parameters are valid.
351        let var: f64 = InvGamma::new(self.a, self.b)
352            .map_err(|err| {
353                panic!("Invalid σ² params when drawing Gaussian: {}", err)
354            })
355            .unwrap()
356            .draw(&mut rng);
357
358        let sigma = if var <= 0.0 { f64::EPSILON } else { var.sqrt() };
359
360        let post_sigma: f64 = self.v.sqrt() * sigma;
361        let mu: f64 = Gaussian::new(self.m, post_sigma)
362            .map_err(|err| {
363                panic!("Invalid μ params when drawing Gaussian: {}", err)
364            })
365            .unwrap()
366            .draw(&mut rng);
367
368        Gaussian::new(mu, sigma).expect("Invalid params")
369    }
370}
371
372impl std::error::Error for NormalInvGammaError {}
373
374impl fmt::Display for NormalInvGammaError {
375    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376        match self {
377            Self::MNotFinite { m } => write!(f, "non-finite m: {}", m),
378            Self::VNotFinite { v } => write!(f, "non-finite v: {}", v),
379            Self::ANotFinite { a } => write!(f, "non-finite a: {}", a),
380            Self::BNotFinite { b } => write!(f, "non-finite b: {}", b),
381            Self::VTooLow { v } => {
382                write!(f, "v ({}) must be greater than zero", v)
383            }
384            Self::ATooLow { a } => {
385                write!(f, "a ({}) must be greater than zero", a)
386            }
387            Self::BTooLow { b } => {
388                write!(f, "b ({}) must be greater than zero", b)
389            }
390        }
391    }
392}
393
#[cfg(test)]
mod test {
    use super::*;
    use crate::test_basic_impls;

    test_basic_impls!(
        Gaussian,
        NormalInvGamma,
        NormalInvGamma::new(0.1, 1.2, 2.3, 3.4).unwrap()
    );

    /// Non-positive `v`, `a`, or `b` must be rejected with the matching variant.
    #[test]
    fn nig_new_rejects_non_positive_params() {
        assert_eq!(
            NormalInvGamma::new(0.0, 0.0, 1.0, 1.0),
            Err(NormalInvGammaError::VTooLow { v: 0.0 })
        );
        assert_eq!(
            NormalInvGamma::new(0.0, 1.0, -1.0, 1.0),
            Err(NormalInvGammaError::ATooLow { a: -1.0 })
        );
        assert_eq!(
            NormalInvGamma::new(0.0, 1.0, 1.0, 0.0),
            Err(NormalInvGammaError::BTooLow { b: 0.0 })
        );
    }

    /// Non-finite parameters must be rejected, and they take precedence over
    /// non-positive ones.
    #[test]
    fn nig_new_rejects_non_finite_params_first() {
        assert!(matches!(
            NormalInvGamma::new(f64::NAN, 1.0, 1.0, 1.0),
            Err(NormalInvGammaError::MNotFinite { .. })
        ));
        assert!(matches!(
            NormalInvGamma::new(0.0, -1.0, f64::INFINITY, 1.0),
            Err(NormalInvGammaError::ANotFinite { .. })
        ));
        assert!(matches!(
            NormalInvGamma::new(0.0, 1.0, 1.0, f64::NEG_INFINITY),
            Err(NormalInvGammaError::BNotFinite { .. })
        ));
    }

    /// A failed checked setter must leave the previous value in place.
    #[test]
    fn nig_failed_setter_leaves_value_unchanged() {
        let mut nig = NormalInvGamma::new(0.1, 1.2, 2.3, 3.4).unwrap();
        assert!(nig.set_v(-1.0).is_err());
        assert_eq!(nig.v(), 1.2);
        assert!(nig.set_b(f64::NAN).is_err());
        assert_eq!(nig.b(), 3.4);
    }

    /// `emit_params` followed by `from_params` must reproduce the prior.
    #[test]
    fn nig_params_round_trip() {
        let nig = NormalInvGamma::new(0.1, 1.2, 2.3, 3.4).unwrap();
        assert_eq!(NormalInvGamma::from_params(nig.emit_params()), nig);
    }

    #[test]
    fn nig_string_conversion() {
        let nig = NormalInvGamma::new(0.1, 1.2, 2.3, 3.4).unwrap();
        assert_eq!(
            String::from(&nig),
            "Normal-Inverse-Gamma(m: 0.1, v: 1.2, a: 2.3, b: 3.4)"
        );
    }
}