Skip to main content

rv/dist/
normal_inv_gamma.rs

//! A common conjugate prior for Gaussians with unknown mean and variance.
//!
//! For reference, see section 6 of [Kevin Murphy's
//! whitepaper](https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf).
5#[cfg(feature = "serde1")]
6use serde::{Deserialize, Serialize};
7
8mod gaussian_prior;
9
10use rand::Rng;
11use std::fmt;
12
13use crate::dist::{Gaussian, InvGamma};
14use crate::impl_display;
15use crate::traits::{HasDensity, Parameterized, Sampleable};
16
/// Normal-Inverse-Gamma prior for a Gaussian with unknown mean and variance.
///
/// Given `x ~ N(μ, σ)` (parameterized by the standard deviation σ), the
/// Normal-Inverse-Gamma prior implies that
///
/// - `σ² ~ InvGamma(a, b)`, and
/// - `μ | σ² ~ N(m, sqrt(v)σ)`, i.e. μ has variance `v·σ²`.
///
/// Parameters validated by [`NormalInvGamma::new`]: `m` must be finite;
/// `v`, `a`, and `b` must be finite and strictly greater than zero.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct NormalInvGamma {
    /// Prior mean of μ
    m: f64,
    /// Variance of μ relative to the data variance σ²
    v: f64,
    /// Shape of the inverse-gamma prior on σ²
    a: f64,
    /// Scale of the inverse-gamma prior on σ²
    b: f64,
}
30
/// Plain parameter bundle for [`NormalInvGamma`].
///
/// Produced by `Parameterized::emit_params` and consumed by
/// `Parameterized::from_params`. `from_params` does **not** validate these
/// values.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct NormalInvGammaParameters {
    /// Prior mean of μ
    pub m: f64,
    /// Variance of μ relative to the data variance σ²
    pub v: f64,
    /// Shape of the inverse-gamma prior on σ²
    pub a: f64,
    /// Scale of the inverse-gamma prior on σ²
    pub b: f64,
}
37
38impl Parameterized for NormalInvGamma {
39    type Parameters = NormalInvGammaParameters;
40
41    fn emit_params(&self) -> Self::Parameters {
42        Self::Parameters {
43            m: self.m(),
44            v: self.v(),
45            a: self.a(),
46            b: self.b(),
47        }
48    }
49
50    fn from_params(params: Self::Parameters) -> Self {
51        Self::new_unchecked(params.m, params.v, params.a, params.b)
52    }
53}
54
/// Error returned when a [`NormalInvGamma`] is constructed or updated with an
/// invalid parameter.
///
/// Each variant carries the rejected value. `m` must be finite; `v`, `a`,
/// and `b` must be finite and strictly greater than zero.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum NormalInvGammaError {
    /// The m parameter is infinite or NaN
    MNotFinite { m: f64 },
    /// The v parameter is less than or equal to zero
    VTooLow { v: f64 },
    /// The v parameter is infinite or NaN
    VNotFinite { v: f64 },
    /// The a parameter is less than or equal to zero
    ATooLow { a: f64 },
    /// The a parameter is infinite or NaN
    ANotFinite { a: f64 },
    /// The b parameter is less than or equal to zero
    BTooLow { b: f64 },
    /// The b parameter is infinite or NaN
    BNotFinite { b: f64 },
}
74
75impl NormalInvGamma {
76    /// Create a new Normal Inverse Gamma distribution
77    ///
78    /// # Arguments
79    /// - m: The prior mean
80    /// - v: Relative variance of μ versus data
81    /// - a: The mean of variance is b / (a - 1)
82    /// - b: Degrees of freedom of the variance
83    pub fn new(
84        m: f64,
85        v: f64,
86        a: f64,
87        b: f64,
88    ) -> Result<Self, NormalInvGammaError> {
89        if !m.is_finite() {
90            Err(NormalInvGammaError::MNotFinite { m })
91        } else if !v.is_finite() {
92            Err(NormalInvGammaError::VNotFinite { v })
93        } else if !a.is_finite() {
94            Err(NormalInvGammaError::ANotFinite { a })
95        } else if !b.is_finite() {
96            Err(NormalInvGammaError::BNotFinite { b })
97        } else if v <= 0.0 {
98            Err(NormalInvGammaError::VTooLow { v })
99        } else if a <= 0.0 {
100            Err(NormalInvGammaError::ATooLow { a })
101        } else if b <= 0.0 {
102            Err(NormalInvGammaError::BTooLow { b })
103        } else {
104            Ok(NormalInvGamma { m, v, a, b })
105        }
106    }
107
108    /// Creates a new `NormalInvGamma` without checking whether the parameters are
109    /// valid.
110    #[inline(always)]
111    #[must_use]
112    pub fn new_unchecked(m: f64, v: f64, a: f64, b: f64) -> Self {
113        NormalInvGamma { m, v, a, b }
114    }
115
116    /// Get the m parameter
117    #[inline(always)]
118    #[must_use]
119    pub fn m(&self) -> f64 {
120        self.m
121    }
122
123    /// Set the value of m
124    ///
125    /// # Example
126    ///
127    /// ```rust
128    /// use rv::dist::NormalInvGamma;
129    ///
130    /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
131    /// assert_eq!(nig.m(), 0.0);
132    ///
133    /// nig.set_m(-1.1).unwrap();
134    /// assert_eq!(nig.m(), -1.1);
135    /// ```
136    ///
137    /// Will error for invalid values
138    ///
139    /// ```rust
140    /// # use rv::dist::NormalInvGamma;
141    /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
142    /// assert!(nig.set_m(-1.1).is_ok());
143    /// assert!(nig.set_m(f64::INFINITY).is_err());
144    /// assert!(nig.set_m(f64::NEG_INFINITY).is_err());
145    /// assert!(nig.set_m(f64::NAN).is_err());
146    /// ```
147    #[inline]
148    pub fn set_m(&mut self, m: f64) -> Result<(), NormalInvGammaError> {
149        if m.is_finite() {
150            self.set_m_unchecked(m);
151            Ok(())
152        } else {
153            Err(NormalInvGammaError::MNotFinite { m })
154        }
155    }
156
157    /// Set the value of m without input validation
158    #[inline(always)]
159    pub fn set_m_unchecked(&mut self, m: f64) {
160        self.m = m;
161    }
162
163    /// Get the v parameter
164    #[inline]
165    #[must_use]
166    pub fn v(&self) -> f64 {
167        self.v
168    }
169
170    /// Set the value of v
171    ///
172    /// # Example
173    ///
174    /// ```rust
175    /// use rv::dist::NormalInvGamma;
176    ///
177    /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
178    /// assert_eq!(nig.v(), 1.2);
179    ///
180    /// nig.set_v(4.3).unwrap();
181    /// assert_eq!(nig.v(), 4.3);
182    /// ```
183    ///
184    /// Will error for invalid values
185    ///
186    /// ```rust
187    /// # use rv::dist::NormalInvGamma;
188    /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
189    /// assert!(nig.set_v(2.1).is_ok());
190    ///
191    /// // must be greater than zero
192    /// assert!(nig.set_v(0.0).is_err());
193    /// assert!(nig.set_v(-1.0).is_err());
194    ///
195    ///
196    /// assert!(nig.set_v(f64::INFINITY).is_err());
197    /// assert!(nig.set_v(f64::NEG_INFINITY).is_err());
198    /// assert!(nig.set_v(f64::NAN).is_err());
199    /// ```
200    #[inline]
201    pub fn set_v(&mut self, v: f64) -> Result<(), NormalInvGammaError> {
202        if !v.is_finite() {
203            Err(NormalInvGammaError::VNotFinite { v })
204        } else if v <= 0.0 {
205            Err(NormalInvGammaError::VTooLow { v })
206        } else {
207            self.set_v_unchecked(v);
208            Ok(())
209        }
210    }
211
212    /// Set the value of v without input validation
213    #[inline]
214    pub fn set_v_unchecked(&mut self, v: f64) {
215        self.v = v;
216    }
217
218    /// Get the a parameter
219    #[inline]
220    #[must_use]
221    pub fn a(&self) -> f64 {
222        self.a
223    }
224
225    /// Set the value of a
226    ///
227    /// # Example
228    ///
229    /// ```rust
230    /// use rv::dist::NormalInvGamma;
231    ///
232    /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
233    /// assert_eq!(nig.a(), 2.3);
234    ///
235    /// nig.set_a(4.3).unwrap();
236    /// assert_eq!(nig.a(), 4.3);
237    /// ```
238    ///
239    /// Will error for invalid values
240    ///
241    /// ```rust
242    /// # use rv::dist::NormalInvGamma;
243    /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
244    /// assert!(nig.set_a(2.1).is_ok());
245    ///
246    /// // must be greater than zero
247    /// assert!(nig.set_a(0.0).is_err());
248    /// assert!(nig.set_a(-1.0).is_err());
249    ///
250    ///
251    /// assert!(nig.set_a(f64::INFINITY).is_err());
252    /// assert!(nig.set_a(f64::NEG_INFINITY).is_err());
253    /// assert!(nig.set_a(f64::NAN).is_err());
254    /// ```
255    #[inline]
256    pub fn set_a(&mut self, a: f64) -> Result<(), NormalInvGammaError> {
257        if !a.is_finite() {
258            Err(NormalInvGammaError::ANotFinite { a })
259        } else if a <= 0.0 {
260            Err(NormalInvGammaError::ATooLow { a })
261        } else {
262            self.set_a_unchecked(a);
263            Ok(())
264        }
265    }
266
267    /// Set the value of a without input validation
268    #[inline]
269    pub fn set_a_unchecked(&mut self, a: f64) {
270        self.a = a;
271    }
272
273    /// Get the b parameter
274    #[inline]
275    #[must_use]
276    pub fn b(&self) -> f64 {
277        self.b
278    }
279
280    /// Set the value of b
281    ///
282    /// # Example
283    ///
284    /// ```rust
285    /// use rv::dist::NormalInvGamma;
286    ///
287    /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
288    /// assert_eq!(nig.b(), 3.4);
289    ///
290    /// nig.set_b(4.3).unwrap();
291    /// assert_eq!(nig.b(), 4.3);
292    /// ```
293    ///
294    /// Will error for invalid values
295    ///
296    /// ```rust
297    /// # use rv::dist::NormalInvGamma;
298    /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
299    /// assert!(nig.set_b(2.1).is_ok());
300    ///
301    /// // must be greater than zero
302    /// assert!(nig.set_b(0.0).is_err());
303    /// assert!(nig.set_b(-1.0).is_err());
304    ///
305    ///
306    /// assert!(nig.set_b(f64::INFINITY).is_err());
307    /// assert!(nig.set_b(f64::NEG_INFINITY).is_err());
308    /// assert!(nig.set_b(f64::NAN).is_err());
309    /// ```
310    #[inline]
311    pub fn set_b(&mut self, b: f64) -> Result<(), NormalInvGammaError> {
312        if !b.is_finite() {
313            Err(NormalInvGammaError::BNotFinite { b })
314        } else if b <= 0.0 {
315            Err(NormalInvGammaError::BTooLow { b })
316        } else {
317            self.set_b_unchecked(b);
318            Ok(())
319        }
320    }
321
322    /// Set the value of b without input validation
323    #[inline(always)]
324    pub fn set_b_unchecked(&mut self, b: f64) {
325        self.b = b;
326    }
327}
328
329impl From<&NormalInvGamma> for String {
330    fn from(nig: &NormalInvGamma) -> String {
331        format!(
332            "Normal-Inverse-Gamma(m: {}, v: {}, a: {}, b: {})",
333            nig.m, nig.v, nig.a, nig.b
334        )
335    }
336}
337
// NOTE(review): `impl_display!` is a crate macro; presumably it implements
// `fmt::Display` for `NormalInvGamma` via the `From<&NormalInvGamma> for
// String` conversion defined in this file — confirm against its definition.
impl_display!(NormalInvGamma);
339
340impl HasDensity<Gaussian> for NormalInvGamma {
341    fn ln_f(&self, x: &Gaussian) -> f64 {
342        // TODO: could cache the gamma and Gaussian distributions
343        let mu = x.mu();
344        let sigma = x.sigma();
345        let lnf_sigma =
346            InvGamma::new_unchecked(self.a, self.b).ln_f(&(sigma * sigma));
347        let prior_sigma = self.v.sqrt() * sigma;
348        let lnf_mu = Gaussian::new_unchecked(self.m, prior_sigma).ln_f(&mu);
349        lnf_sigma + lnf_mu
350    }
351}
352
353impl Sampleable<Gaussian> for NormalInvGamma {
354    fn draw<R: Rng>(&self, mut rng: &mut R) -> Gaussian {
355        // NOTE: The parameter errors in this fn shouldn't happen if the prior
356        // parameters are valid.
357        let var: f64 = InvGamma::new(self.a, self.b)
358            .map_err(|err| {
359                panic!("Invalid σ² params when drawing Gaussian: {err}")
360            })
361            .unwrap()
362            .draw(&mut rng);
363
364        let sigma = if var <= 0.0 { f64::EPSILON } else { var.sqrt() };
365
366        let post_sigma: f64 = self.v.sqrt() * sigma;
367        let mu: f64 = Gaussian::new(self.m, post_sigma)
368            .map_err(|err| {
369                panic!("Invalid μ params when drawing Gaussian: {err}")
370            })
371            .unwrap()
372            .draw(&mut rng);
373
374        Gaussian::new(mu, sigma).expect("Invalid params")
375    }
376}
377
378impl std::error::Error for NormalInvGammaError {}
379
380#[cfg_attr(coverage_nightly, coverage(off))]
381impl fmt::Display for NormalInvGammaError {
382    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
383        match self {
384            Self::MNotFinite { m } => write!(f, "non-finite m: {m}"),
385            Self::VNotFinite { v } => write!(f, "non-finite v: {v}"),
386            Self::ANotFinite { a } => write!(f, "non-finite a: {a}"),
387            Self::BNotFinite { b } => write!(f, "non-finite b: {b}"),
388            Self::VTooLow { v } => {
389                write!(f, "v ({v}) must be greater than zero")
390            }
391            Self::ATooLow { a } => {
392                write!(f, "a ({a}) must be greater than zero")
393            }
394            Self::BTooLow { b } => {
395                write!(f, "b ({b}) must be greater than zero")
396            }
397        }
398    }
399}
400
// Unit tests. `test_basic_impls!` is a crate macro invoked with the sampled
// type (`Gaussian`), the prior type, and a valid prior instance; presumably it
// generates the crate's standard trait-impl checks — confirm against its
// definition.
#[cfg(test)]
mod test {
    use super::*;
    use crate::test_basic_impls;

    test_basic_impls!(
        Gaussian,
        NormalInvGamma,
        NormalInvGamma::new(0.1, 1.2, 2.3, 3.4).unwrap()
    );
}