rv/dist/normal_inv_gamma.rs
//! A common conjugate prior for Gaussians with unknown mean and variance.
//!
//! For a reference, see section 6 of [Kevin Murphy's
//! whitepaper](https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf).
5#[cfg(feature = "serde1")]
6use serde::{Deserialize, Serialize};
7
8mod gaussian_prior;
9
10use crate::dist::{Gaussian, InvGamma};
11use crate::impl_display;
12use crate::traits::*;
13use rand::Rng;
14use std::fmt;
15
/// Normal Inverse Gamma prior for a Gaussian with unknown mean and variance.
///
/// Given `x ~ N(μ, σ)` (parameterized by mean and standard deviation, like
/// [`Gaussian`]), the Normal Inverse Gamma prior implies that
/// `σ² ~ InvGamma(a, b)` and, conditioned on σ, `μ ~ N(m, sqrt(v)σ)`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct NormalInvGamma {
    /// Prior mean of μ
    m: f64,
    /// Scales the prior variance of μ relative to σ² (std. dev. of μ is
    /// `sqrt(v)·σ`)
    v: f64,
    /// Shape of the inverse-gamma prior on σ²
    a: f64,
    /// Scale of the inverse-gamma prior on σ²
    b: f64,
}
29
/// Plain-data hyperparameters of a [`NormalInvGamma`], as exchanged through
/// the [`Parameterized`] trait.
///
/// Fields are public and unvalidated; rebuilding a `NormalInvGamma` from
/// them via `from_params` goes through `new_unchecked`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct NormalInvGammaParameters {
    /// Prior mean of μ
    pub m: f64,
    /// Relative variance of μ versus the data
    pub v: f64,
    /// Shape of the inverse-gamma prior on σ²
    pub a: f64,
    /// Scale of the inverse-gamma prior on σ²
    pub b: f64,
}
36
37impl Parameterized for NormalInvGamma {
38 type Parameters = NormalInvGammaParameters;
39
40 fn emit_params(&self) -> Self::Parameters {
41 Self::Parameters {
42 m: self.m(),
43 v: self.v(),
44 a: self.a(),
45 b: self.b(),
46 }
47 }
48
49 fn from_params(params: Self::Parameters) -> Self {
50 Self::new_unchecked(params.m, params.v, params.a, params.b)
51 }
52}
53
/// Error returned when constructing or mutating a [`NormalInvGamma`] with an
/// invalid parameter. Each variant carries the rejected value.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum NormalInvGammaError {
    /// The m parameter is infinite or NaN
    MNotFinite { m: f64 },
    /// The v parameter is less than or equal to zero
    VTooLow { v: f64 },
    /// The v parameter is infinite or NaN
    VNotFinite { v: f64 },
    /// The a parameter is less than or equal to zero
    ATooLow { a: f64 },
    /// The a parameter is infinite or NaN
    ANotFinite { a: f64 },
    /// The b parameter is less than or equal to zero
    BTooLow { b: f64 },
    /// The b parameter is infinite or NaN
    BNotFinite { b: f64 },
}
73
74impl NormalInvGamma {
75 /// Create a new Normal Inverse Gamma distribution
76 ///
77 /// # Arguments
78 /// - m: The prior mean
79 /// - v: Relative variance of μ versus data
80 /// - a: The mean of variance is b / (a - 1)
81 /// - b: Degrees of freedom of the variance
82 pub fn new(
83 m: f64,
84 v: f64,
85 a: f64,
86 b: f64,
87 ) -> Result<Self, NormalInvGammaError> {
88 if !m.is_finite() {
89 Err(NormalInvGammaError::MNotFinite { m })
90 } else if !v.is_finite() {
91 Err(NormalInvGammaError::VNotFinite { v })
92 } else if !a.is_finite() {
93 Err(NormalInvGammaError::ANotFinite { a })
94 } else if !b.is_finite() {
95 Err(NormalInvGammaError::BNotFinite { b })
96 } else if v <= 0.0 {
97 Err(NormalInvGammaError::VTooLow { v })
98 } else if a <= 0.0 {
99 Err(NormalInvGammaError::ATooLow { a })
100 } else if b <= 0.0 {
101 Err(NormalInvGammaError::BTooLow { b })
102 } else {
103 Ok(NormalInvGamma { m, v, a, b })
104 }
105 }
106
107 /// Creates a new NormalInvGamma without checking whether the parameters are
108 /// valid.
109 #[inline(always)]
110 pub fn new_unchecked(m: f64, v: f64, a: f64, b: f64) -> Self {
111 NormalInvGamma { m, v, a, b }
112 }
113
114 /// Get the m parameter
115 #[inline(always)]
116 pub fn m(&self) -> f64 {
117 self.m
118 }
119
120 /// Set the value of m
121 ///
122 /// # Example
123 ///
124 /// ```rust
125 /// use rv::dist::NormalInvGamma;
126 ///
127 /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
128 /// assert_eq!(nig.m(), 0.0);
129 ///
130 /// nig.set_m(-1.1).unwrap();
131 /// assert_eq!(nig.m(), -1.1);
132 /// ```
133 ///
134 /// Will error for invalid values
135 ///
136 /// ```rust
137 /// # use rv::dist::NormalInvGamma;
138 /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
139 /// assert!(nig.set_m(-1.1).is_ok());
140 /// assert!(nig.set_m(f64::INFINITY).is_err());
141 /// assert!(nig.set_m(f64::NEG_INFINITY).is_err());
142 /// assert!(nig.set_m(f64::NAN).is_err());
143 /// ```
144 #[inline]
145 pub fn set_m(&mut self, m: f64) -> Result<(), NormalInvGammaError> {
146 if m.is_finite() {
147 self.set_m_unchecked(m);
148 Ok(())
149 } else {
150 Err(NormalInvGammaError::MNotFinite { m })
151 }
152 }
153
154 /// Set the value of m without input validation
155 #[inline(always)]
156 pub fn set_m_unchecked(&mut self, m: f64) {
157 self.m = m;
158 }
159
160 /// Get the v parameter
161 #[inline]
162 pub fn v(&self) -> f64 {
163 self.v
164 }
165
166 /// Set the value of v
167 ///
168 /// # Example
169 ///
170 /// ```rust
171 /// use rv::dist::NormalInvGamma;
172 ///
173 /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
174 /// assert_eq!(nig.v(), 1.2);
175 ///
176 /// nig.set_v(4.3).unwrap();
177 /// assert_eq!(nig.v(), 4.3);
178 /// ```
179 ///
180 /// Will error for invalid values
181 ///
182 /// ```rust
183 /// # use rv::dist::NormalInvGamma;
184 /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
185 /// assert!(nig.set_v(2.1).is_ok());
186 ///
187 /// // must be greater than zero
188 /// assert!(nig.set_v(0.0).is_err());
189 /// assert!(nig.set_v(-1.0).is_err());
190 ///
191 ///
192 /// assert!(nig.set_v(f64::INFINITY).is_err());
193 /// assert!(nig.set_v(f64::NEG_INFINITY).is_err());
194 /// assert!(nig.set_v(f64::NAN).is_err());
195 /// ```
196 #[inline]
197 pub fn set_v(&mut self, v: f64) -> Result<(), NormalInvGammaError> {
198 if !v.is_finite() {
199 Err(NormalInvGammaError::VNotFinite { v })
200 } else if v <= 0.0 {
201 Err(NormalInvGammaError::VTooLow { v })
202 } else {
203 self.set_v_unchecked(v);
204 Ok(())
205 }
206 }
207
208 /// Set the value of v without input validation
209 #[inline]
210 pub fn set_v_unchecked(&mut self, v: f64) {
211 self.v = v;
212 }
213
214 /// Get the a parameter
215 #[inline]
216 pub fn a(&self) -> f64 {
217 self.a
218 }
219
220 /// Set the value of a
221 ///
222 /// # Example
223 ///
224 /// ```rust
225 /// use rv::dist::NormalInvGamma;
226 ///
227 /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
228 /// assert_eq!(nig.a(), 2.3);
229 ///
230 /// nig.set_a(4.3).unwrap();
231 /// assert_eq!(nig.a(), 4.3);
232 /// ```
233 ///
234 /// Will error for invalid values
235 ///
236 /// ```rust
237 /// # use rv::dist::NormalInvGamma;
238 /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
239 /// assert!(nig.set_a(2.1).is_ok());
240 ///
241 /// // must be greater than zero
242 /// assert!(nig.set_a(0.0).is_err());
243 /// assert!(nig.set_a(-1.0).is_err());
244 ///
245 ///
246 /// assert!(nig.set_a(f64::INFINITY).is_err());
247 /// assert!(nig.set_a(f64::NEG_INFINITY).is_err());
248 /// assert!(nig.set_a(f64::NAN).is_err());
249 /// ```
250 #[inline]
251 pub fn set_a(&mut self, a: f64) -> Result<(), NormalInvGammaError> {
252 if !a.is_finite() {
253 Err(NormalInvGammaError::ANotFinite { a })
254 } else if a <= 0.0 {
255 Err(NormalInvGammaError::ATooLow { a })
256 } else {
257 self.set_a_unchecked(a);
258 Ok(())
259 }
260 }
261
262 /// Set the value of a without input validation
263 #[inline]
264 pub fn set_a_unchecked(&mut self, a: f64) {
265 self.a = a;
266 }
267
268 /// Get the b parameter
269 #[inline]
270 pub fn b(&self) -> f64 {
271 self.b
272 }
273
274 /// Set the value of b
275 ///
276 /// # Example
277 ///
278 /// ```rust
279 /// use rv::dist::NormalInvGamma;
280 ///
281 /// let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
282 /// assert_eq!(nig.b(), 3.4);
283 ///
284 /// nig.set_b(4.3).unwrap();
285 /// assert_eq!(nig.b(), 4.3);
286 /// ```
287 ///
288 /// Will error for invalid values
289 ///
290 /// ```rust
291 /// # use rv::dist::NormalInvGamma;
292 /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap();
293 /// assert!(nig.set_b(2.1).is_ok());
294 ///
295 /// // must be greater than zero
296 /// assert!(nig.set_b(0.0).is_err());
297 /// assert!(nig.set_b(-1.0).is_err());
298 ///
299 ///
300 /// assert!(nig.set_b(f64::INFINITY).is_err());
301 /// assert!(nig.set_b(f64::NEG_INFINITY).is_err());
302 /// assert!(nig.set_b(f64::NAN).is_err());
303 /// ```
304 #[inline]
305 pub fn set_b(&mut self, b: f64) -> Result<(), NormalInvGammaError> {
306 if !b.is_finite() {
307 Err(NormalInvGammaError::BNotFinite { b })
308 } else if b <= 0.0 {
309 Err(NormalInvGammaError::BTooLow { b })
310 } else {
311 self.set_b_unchecked(b);
312 Ok(())
313 }
314 }
315
316 /// Set the value of b without input validation
317 #[inline(always)]
318 pub fn set_b_unchecked(&mut self, b: f64) {
319 self.b = b;
320 }
321}
322
323impl From<&NormalInvGamma> for String {
324 fn from(nig: &NormalInvGamma) -> String {
325 format!(
326 "Normal-Inverse-Gamma(m: {}, v: {}, a: {}, b: {})",
327 nig.m, nig.v, nig.a, nig.b
328 )
329 }
330}
331
// Display for NormalInvGamma via the crate's `impl_display!` macro.
// NOTE(review): presumably this formats through the
// `From<&NormalInvGamma> for String` impl; confirm against the macro definition.
impl_display!(NormalInvGamma);
333
334impl HasDensity<Gaussian> for NormalInvGamma {
335 fn ln_f(&self, x: &Gaussian) -> f64 {
336 // TODO: could cache the gamma and Gaussian distributions
337 let mu = x.mu();
338 let sigma = x.sigma();
339 let lnf_sigma =
340 InvGamma::new_unchecked(self.a, self.b).ln_f(&(sigma * sigma));
341 let prior_sigma = self.v.sqrt() * sigma;
342 let lnf_mu = Gaussian::new_unchecked(self.m, prior_sigma).ln_f(&mu);
343 lnf_sigma + lnf_mu
344 }
345}
346
347impl Sampleable<Gaussian> for NormalInvGamma {
348 fn draw<R: Rng>(&self, mut rng: &mut R) -> Gaussian {
349 // NOTE: The parameter errors in this fn shouldn't happen if the prior
350 // parameters are valid.
351 let var: f64 = InvGamma::new(self.a, self.b)
352 .map_err(|err| {
353 panic!("Invalid σ² params when drawing Gaussian: {}", err)
354 })
355 .unwrap()
356 .draw(&mut rng);
357
358 let sigma = if var <= 0.0 { f64::EPSILON } else { var.sqrt() };
359
360 let post_sigma: f64 = self.v.sqrt() * sigma;
361 let mu: f64 = Gaussian::new(self.m, post_sigma)
362 .map_err(|err| {
363 panic!("Invalid μ params when drawing Gaussian: {}", err)
364 })
365 .unwrap()
366 .draw(&mut rng);
367
368 Gaussian::new(mu, sigma).expect("Invalid params")
369 }
370}
371
// Marker impl so NormalInvGammaError can be used as a `std::error::Error`
// (e.g. boxed as `dyn Error`); the message comes from the `Display` impl.
impl std::error::Error for NormalInvGammaError {}
373
374impl fmt::Display for NormalInvGammaError {
375 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376 match self {
377 Self::MNotFinite { m } => write!(f, "non-finite m: {}", m),
378 Self::VNotFinite { v } => write!(f, "non-finite v: {}", v),
379 Self::ANotFinite { a } => write!(f, "non-finite a: {}", a),
380 Self::BNotFinite { b } => write!(f, "non-finite b: {}", b),
381 Self::VTooLow { v } => {
382 write!(f, "v ({}) must be greater than zero", v)
383 }
384 Self::ATooLow { a } => {
385 write!(f, "a ({}) must be greater than zero", a)
386 }
387 Self::BTooLow { b } => {
388 write!(f, "b ({}) must be greater than zero", b)
389 }
390 }
391 }
392}
393
#[cfg(test)]
mod test {
    use super::*;
    use crate::test_basic_impls;

    // Shared smoke tests generated by the crate's `test_basic_impls!` macro,
    // instantiated with Gaussian as the sampled type and NormalInvGamma as the
    // distribution under test. The exact checks are defined by the macro,
    // which is not visible here.
    test_basic_impls!(
        Gaussian,
        NormalInvGamma,
        NormalInvGamma::new(0.1, 1.2, 2.3, 3.4).unwrap()
    );
}
404}