rv/dist/normal_inv_gamma.rs
//! The Normal-Inverse-Gamma distribution: a common conjugate prior for
//! Gaussians with unknown mean and variance.
//!
//! For a reference, see section 6 of [Kevin Murphy's
//! whitepaper](https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf).
5#[cfg(feature = "serde1")]
6use serde::{Deserialize, Serialize};
7
8mod gaussian_prior;
9
10use rand::Rng;
11use std::fmt;
12
13use crate::dist::{Gaussian, InvGamma};
14use crate::impl_display;
15use crate::traits::{HasDensity, Parameterized, Sampleable};
16
/// Normal-Inverse-Gamma prior for a [`Gaussian`] with unknown mean and
/// variance.
///
/// Given `x ~ N(μ, σ)`, the Normal Inverse Gamma prior implies that
/// `σ² ~ InvGamma(a, b)` and, conditioned on σ, `μ ~ N(m, sqrt(v)σ)`.
/// Here `N(·, ·)` is parameterized by its mean and standard deviation.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct NormalInvGamma {
    /// Prior mean of μ.
    m: f64,
    /// Relative variance of μ versus the data: μ has standard deviation
    /// `sqrt(v)σ`.
    v: f64,
    /// First (shape) parameter of the inverse-gamma prior on σ².
    a: f64,
    /// Second (scale) parameter of the inverse-gamma prior on σ².
    b: f64,
}
30
/// The parameters of a [`NormalInvGamma`] distribution.
///
/// Produced by `Parameterized::emit_params` and consumed by
/// `Parameterized::from_params`. Values are not validated on their own.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct NormalInvGammaParameters {
    /// Prior mean of μ
    pub m: f64,
    /// Relative variance of μ versus the data
    pub v: f64,
    /// Shape of the inverse-gamma prior on σ²
    pub a: f64,
    /// Scale of the inverse-gamma prior on σ²
    pub b: f64,
}
37
38impl Parameterized for NormalInvGamma {
39 type Parameters = NormalInvGammaParameters;
40
41 fn emit_params(&self) -> Self::Parameters {
42 Self::Parameters {
43 m: self.m(),
44 v: self.v(),
45 a: self.a(),
46 b: self.b(),
47 }
48 }
49
50 fn from_params(params: Self::Parameters) -> Self {
51 Self::new_unchecked(params.m, params.v, params.a, params.b)
52 }
53}
54
/// Errors produced when constructing or updating a [`NormalInvGamma`] with
/// invalid parameters.
///
/// Every variant carries the offending value.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum NormalInvGammaError {
    /// The m parameter is infinite or NaN
    MNotFinite {
        /// The rejected value of m
        m: f64,
    },
    /// The v parameter is less than or equal to zero
    VTooLow {
        /// The rejected value of v
        v: f64,
    },
    /// The v parameter is infinite or NaN
    VNotFinite {
        /// The rejected value of v
        v: f64,
    },
    /// The a parameter is less than or equal to zero
    ATooLow {
        /// The rejected value of a
        a: f64,
    },
    /// The a parameter is infinite or NaN
    ANotFinite {
        /// The rejected value of a
        a: f64,
    },
    /// The b parameter is less than or equal to zero
    BTooLow {
        /// The rejected value of b
        b: f64,
    },
    /// The b parameter is infinite or NaN
    BNotFinite {
        /// The rejected value of b
        b: f64,
    },
}
74
75impl NormalInvGamma {
76 /// Create a new Normal Inverse Gamma distribution
77 ///
78 /// # Arguments
79 /// - m: The prior mean
80 /// - v: Relative variance of μ versus data
81 /// - a: The mean of variance is b / (a - 1)
82 /// - b: Degrees of freedom of the variance
83 pub fn new(
84 m: f64,
85 v: f64,
86 a: f64,
87 b: f64,
88 ) -> Result<Self, NormalInvGammaError> {
89 if !m.is_finite() {
90 Err(NormalInvGammaError::MNotFinite { m })
91 } else if !v.is_finite() {
92 Err(NormalInvGammaError::VNotFinite { v })
93 } else if !a.is_finite() {
94 Err(NormalInvGammaError::ANotFinite { a })
95 } else if !b.is_finite() {
96 Err(NormalInvGammaError::BNotFinite { b })
97 } else if v <= 0.0 {
98 Err(NormalInvGammaError::VTooLow { v })
99 } else if a <= 0.0 {
100 Err(NormalInvGammaError::ATooLow { a })
101 } else if b <= 0.0 {
102 Err(NormalInvGammaError::BTooLow { b })
103 } else {
104 Ok(NormalInvGamma { m, v, a, b })
105 }
106 }
107
108 /// Creates a new `NormalInvGamma` without checking whether the parameters are
109 /// valid.
110 #[inline(always)]
111 #[must_use]
112 pub fn new_unchecked(m: f64, v: f64, a: f64, b: f64) -> Self {
113 NormalInvGamma { m, v, a, b }
114 }
115
116 /// Get the m parameter
117 #[inline(always)]
118 #[must_use]
119 pub fn m(&self) -> f64 {
120 self.m
121 }
122
123 /// Set the value of m
124 ///
125 /// # Example
126 ///
127 /// ```rust
128 /// use rv::dist::NormalInvGamma;
129 ///
130 /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
131 /// assert_eq!(nig.m(), 0.0);
132 ///
133 /// nig.set_m(-1.1).unwrap();
134 /// assert_eq!(nig.m(), -1.1);
135 /// ```
136 ///
137 /// Will error for invalid values
138 ///
139 /// ```rust
140 /// # use rv::dist::NormalInvGamma;
141 /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
142 /// assert!(nig.set_m(-1.1).is_ok());
143 /// assert!(nig.set_m(f64::INFINITY).is_err());
144 /// assert!(nig.set_m(f64::NEG_INFINITY).is_err());
145 /// assert!(nig.set_m(f64::NAN).is_err());
146 /// ```
147 #[inline]
148 pub fn set_m(&mut self, m: f64) -> Result<(), NormalInvGammaError> {
149 if m.is_finite() {
150 self.set_m_unchecked(m);
151 Ok(())
152 } else {
153 Err(NormalInvGammaError::MNotFinite { m })
154 }
155 }
156
157 /// Set the value of m without input validation
158 #[inline(always)]
159 pub fn set_m_unchecked(&mut self, m: f64) {
160 self.m = m;
161 }
162
163 /// Get the v parameter
164 #[inline]
165 #[must_use]
166 pub fn v(&self) -> f64 {
167 self.v
168 }
169
170 /// Set the value of v
171 ///
172 /// # Example
173 ///
174 /// ```rust
175 /// use rv::dist::NormalInvGamma;
176 ///
177 /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
178 /// assert_eq!(nig.v(), 1.2);
179 ///
180 /// nig.set_v(4.3).unwrap();
181 /// assert_eq!(nig.v(), 4.3);
182 /// ```
183 ///
184 /// Will error for invalid values
185 ///
186 /// ```rust
187 /// # use rv::dist::NormalInvGamma;
188 /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
189 /// assert!(nig.set_v(2.1).is_ok());
190 ///
191 /// // must be greater than zero
192 /// assert!(nig.set_v(0.0).is_err());
193 /// assert!(nig.set_v(-1.0).is_err());
194 ///
195 ///
196 /// assert!(nig.set_v(f64::INFINITY).is_err());
197 /// assert!(nig.set_v(f64::NEG_INFINITY).is_err());
198 /// assert!(nig.set_v(f64::NAN).is_err());
199 /// ```
200 #[inline]
201 pub fn set_v(&mut self, v: f64) -> Result<(), NormalInvGammaError> {
202 if !v.is_finite() {
203 Err(NormalInvGammaError::VNotFinite { v })
204 } else if v <= 0.0 {
205 Err(NormalInvGammaError::VTooLow { v })
206 } else {
207 self.set_v_unchecked(v);
208 Ok(())
209 }
210 }
211
212 /// Set the value of v without input validation
213 #[inline]
214 pub fn set_v_unchecked(&mut self, v: f64) {
215 self.v = v;
216 }
217
218 /// Get the a parameter
219 #[inline]
220 #[must_use]
221 pub fn a(&self) -> f64 {
222 self.a
223 }
224
225 /// Set the value of a
226 ///
227 /// # Example
228 ///
229 /// ```rust
230 /// use rv::dist::NormalInvGamma;
231 ///
232 /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
233 /// assert_eq!(nig.a(), 2.3);
234 ///
235 /// nig.set_a(4.3).unwrap();
236 /// assert_eq!(nig.a(), 4.3);
237 /// ```
238 ///
239 /// Will error for invalid values
240 ///
241 /// ```rust
242 /// # use rv::dist::NormalInvGamma;
243 /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
244 /// assert!(nig.set_a(2.1).is_ok());
245 ///
246 /// // must be greater than zero
247 /// assert!(nig.set_a(0.0).is_err());
248 /// assert!(nig.set_a(-1.0).is_err());
249 ///
250 ///
251 /// assert!(nig.set_a(f64::INFINITY).is_err());
252 /// assert!(nig.set_a(f64::NEG_INFINITY).is_err());
253 /// assert!(nig.set_a(f64::NAN).is_err());
254 /// ```
255 #[inline]
256 pub fn set_a(&mut self, a: f64) -> Result<(), NormalInvGammaError> {
257 if !a.is_finite() {
258 Err(NormalInvGammaError::ANotFinite { a })
259 } else if a <= 0.0 {
260 Err(NormalInvGammaError::ATooLow { a })
261 } else {
262 self.set_a_unchecked(a);
263 Ok(())
264 }
265 }
266
267 /// Set the value of a without input validation
268 #[inline]
269 pub fn set_a_unchecked(&mut self, a: f64) {
270 self.a = a;
271 }
272
273 /// Get the b parameter
274 #[inline]
275 #[must_use]
276 pub fn b(&self) -> f64 {
277 self.b
278 }
279
280 /// Set the value of b
281 ///
282 /// # Example
283 ///
284 /// ```rust
285 /// use rv::dist::NormalInvGamma;
286 ///
287 /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
288 /// assert_eq!(nig.b(), 3.4);
289 ///
290 /// nig.set_b(4.3).unwrap();
291 /// assert_eq!(nig.b(), 4.3);
292 /// ```
293 ///
294 /// Will error for invalid values
295 ///
296 /// ```rust
297 /// # use rv::dist::NormalInvGamma;
298 /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
299 /// assert!(nig.set_b(2.1).is_ok());
300 ///
301 /// // must be greater than zero
302 /// assert!(nig.set_b(0.0).is_err());
303 /// assert!(nig.set_b(-1.0).is_err());
304 ///
305 ///
306 /// assert!(nig.set_b(f64::INFINITY).is_err());
307 /// assert!(nig.set_b(f64::NEG_INFINITY).is_err());
308 /// assert!(nig.set_b(f64::NAN).is_err());
309 /// ```
310 #[inline]
311 pub fn set_b(&mut self, b: f64) -> Result<(), NormalInvGammaError> {
312 if !b.is_finite() {
313 Err(NormalInvGammaError::BNotFinite { b })
314 } else if b <= 0.0 {
315 Err(NormalInvGammaError::BTooLow { b })
316 } else {
317 self.set_b_unchecked(b);
318 Ok(())
319 }
320 }
321
322 /// Set the value of b without input validation
323 #[inline(always)]
324 pub fn set_b_unchecked(&mut self, b: f64) {
325 self.b = b;
326 }
327}
328
329impl From<&NormalInvGamma> for String {
330 fn from(nig: &NormalInvGamma) -> String {
331 format!(
332 "Normal-Inverse-Gamma(m: {}, v: {}, a: {}, b: {})",
333 nig.m, nig.v, nig.a, nig.b
334 )
335 }
336}
337
// Presumably implements `Display` for `NormalInvGamma` via the crate's
// `impl_display!` macro. NOTE(review): it likely delegates to the
// `From<&NormalInvGamma> for String` conversion in this file — confirm against
// the macro definition.
impl_display!(NormalInvGamma);
339
340impl HasDensity<Gaussian> for NormalInvGamma {
341 fn ln_f(&self, x: &Gaussian) -> f64 {
342 // TODO: could cache the gamma and Gaussian distributions
343 let mu = x.mu();
344 let sigma = x.sigma();
345 let lnf_sigma =
346 InvGamma::new_unchecked(self.a, self.b).ln_f(&(sigma * sigma));
347 let prior_sigma = self.v.sqrt() * sigma;
348 let lnf_mu = Gaussian::new_unchecked(self.m, prior_sigma).ln_f(&mu);
349 lnf_sigma + lnf_mu
350 }
351}
352
353impl Sampleable<Gaussian> for NormalInvGamma {
354 fn draw<R: Rng>(&self, mut rng: &mut R) -> Gaussian {
355 // NOTE: The parameter errors in this fn shouldn't happen if the prior
356 // parameters are valid.
357 let var: f64 = InvGamma::new(self.a, self.b)
358 .map_err(|err| {
359 panic!("Invalid σ² params when drawing Gaussian: {err}")
360 })
361 .unwrap()
362 .draw(&mut rng);
363
364 let sigma = if var <= 0.0 { f64::EPSILON } else { var.sqrt() };
365
366 let post_sigma: f64 = self.v.sqrt() * sigma;
367 let mu: f64 = Gaussian::new(self.m, post_sigma)
368 .map_err(|err| {
369 panic!("Invalid μ params when drawing Gaussian: {err}")
370 })
371 .unwrap()
372 .draw(&mut rng);
373
374 Gaussian::new(mu, sigma).expect("Invalid params")
375 }
376}
377
378impl std::error::Error for NormalInvGammaError {}
379
380#[cfg_attr(coverage_nightly, coverage(off))]
381impl fmt::Display for NormalInvGammaError {
382 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
383 match self {
384 Self::MNotFinite { m } => write!(f, "non-finite m: {m}"),
385 Self::VNotFinite { v } => write!(f, "non-finite v: {v}"),
386 Self::ANotFinite { a } => write!(f, "non-finite a: {a}"),
387 Self::BNotFinite { b } => write!(f, "non-finite b: {b}"),
388 Self::VTooLow { v } => {
389 write!(f, "v ({v}) must be greater than zero")
390 }
391 Self::ATooLow { a } => {
392 write!(f, "a ({a}) must be greater than zero")
393 }
394 Self::BTooLow { b } => {
395 write!(f, "b ({b}) must be greater than zero")
396 }
397 }
398 }
399}
400
#[cfg(test)]
mod test {
    use super::*;
    use crate::test_basic_impls;

    // Runs the crate's shared `test_basic_impls!` checks for `NormalInvGamma`
    // as a prior over `Gaussian`, using the instance built below.
    // NOTE(review): the exact set of checks lives in the macro definition —
    // confirm there.
    test_basic_impls!(
        Gaussian,
        NormalInvGamma,
        NormalInvGamma::new(0.1, 1.2, 2.3, 3.4).unwrap()
    );
}