rv 0.15.0-rc.1

Random variables
Documentation
//! A common conjugate prior for Gaussians with unknown mean and variance
//!
//! For a reference see section 6 of [Kevin Murphy's
//! whitepaper](https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf).
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

mod gaussian_prior;

use crate::dist::{Gaussian, ScaledInvChiSquared};
use crate::impl_display;
use crate::traits::Rv;
use once_cell::sync::OnceCell;
use rand::Rng;

/// Prior for Gaussian
///
/// Given `x ~ N(μ, σ)`, the Normal Inverse Chi Squared prior implies that
/// `μ ~ N(m, σ/√k)` and `σ² ~ ScaledInvChiSquared(v, s2)`.
///
/// Equality is implemented by hand and compares only the four parameters;
/// the lazily built cache is ignored.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct NormalInvChiSquared {
    /// The prior mean
    m: f64,
    /// Strength of belief in the prior mean (in prior pseudo-observations)
    k: f64,
    /// Strength of belief in the prior variance (in prior
    /// pseudo-observations)
    v: f64,
    /// The prior variance
    s2: f64,
    /// Cached scaled inv X^2
    ///
    /// Built on demand from `v` and `s2` by `scaled_inv_x2()` and replaced
    /// with an empty cell whenever either of those parameters is set. It is
    /// skipped by serde when the `serde1` feature is enabled.
    #[cfg_attr(feature = "serde1", serde(skip))]
    scaled_inv_x2: OnceCell<ScaledInvChiSquared>,
}

impl PartialEq for NormalInvChiSquared {
    /// Two priors are equal when all four parameters are equal.
    ///
    /// The cached `ScaledInvChiSquared` is deliberately left out of the
    /// comparison. As with any `f64` comparison, a NaN parameter makes the
    /// result `false`.
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.m, self.k, self.v, self.s2);
        let rhs = (other.m, other.k, other.v, other.s2);
        lhs == rhs
    }
}

/// Errors returned by `NormalInvChiSquared::new` and the checked setters
/// (`set_m`, `set_k`, `set_v`, `set_s2`) when a parameter is invalid.
///
/// Each variant carries the offending value.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum NormalInvChiSquaredError {
    /// The m parameter is infinite or NaN
    MNotFinite { m: f64 },
    /// The k parameter is less than or equal to zero
    KTooLow { k: f64 },
    /// The k parameter is infinite or NaN
    KNotFinite { k: f64 },
    /// The v parameter is less than or equal to zero
    VTooLow { v: f64 },
    /// The v parameter is infinite or NaN
    VNotFinite { v: f64 },
    /// The s2 parameter is less than or equal to zero
    S2TooLow { s2: f64 },
    /// The s2 parameter is infinite or NaN
    S2NotFinite { s2: f64 },
}

impl NormalInvChiSquared {
    /// Create a new Normal Inverse Chi Squared distribution
    ///
    /// # Arguments
    /// - m: The prior mean
    /// - k: How strongly we believe the prior mean (in prior
    ///   pseudo-observations)
    /// - v: How strongly we believe the prior variance (in prior
    ///   pseudo-observations)
    /// - s2: The prior variance
    ///
    /// # Errors
    ///
    /// Returns an error if any parameter is infinite or NaN, or if `k`, `v`,
    /// or `s2` is less than or equal to zero. Only the first failure is
    /// reported: finiteness is checked for `m`, `k`, `v`, `s2` (in that
    /// order) before positivity is checked for `v`, `k`, `s2` (in that
    /// order).
    pub fn new(
        m: f64,
        k: f64,
        v: f64,
        s2: f64,
    ) -> Result<Self, NormalInvChiSquaredError> {
        // Finiteness checks come first so that NaN/±∞ are never reported as
        // a "too low" error.
        if !m.is_finite() {
            return Err(NormalInvChiSquaredError::MNotFinite { m });
        }
        if !k.is_finite() {
            return Err(NormalInvChiSquaredError::KNotFinite { k });
        }
        if !v.is_finite() {
            return Err(NormalInvChiSquaredError::VNotFinite { v });
        }
        if !s2.is_finite() {
            return Err(NormalInvChiSquaredError::S2NotFinite { s2 });
        }

        // Positivity checks.
        if v <= 0.0 {
            return Err(NormalInvChiSquaredError::VTooLow { v });
        }
        if k <= 0.0 {
            return Err(NormalInvChiSquaredError::KTooLow { k });
        }
        if s2 <= 0.0 {
            return Err(NormalInvChiSquaredError::S2TooLow { s2 });
        }

        Ok(Self::new_unchecked(m, k, v, s2))
    }

    /// Creates a new NormalInvChiSquared without checking whether the
    /// parameters are valid.
    ///
    /// The cache starts out empty.
    #[inline]
    pub fn new_unchecked(m: f64, k: f64, v: f64, s2: f64) -> Self {
        NormalInvChiSquared {
            m,
            k,
            v,
            s2,
            scaled_inv_x2: OnceCell::new(),
        }
    }

    /// Returns (m, k, v, s2)
    #[inline]
    pub fn params(&self) -> (f64, f64, f64, f64) {
        (self.m, self.k, self.v, self.s2)
    }

    /// Get the m parameter
    #[inline]
    pub fn m(&self) -> f64 {
        self.m
    }

    /// Set the value of m
    ///
    /// # Example
    ///
    /// ```rust
    /// use rv::dist::NormalInvChiSquared;
    ///
    /// let mut nix = NormalInvChiSquared::new(0.0, 1.2, 2.3, 3.4).unwrap();
    /// assert_eq!(nix.m(), 0.0);
    ///
    /// nix.set_m(-1.1).unwrap();
    /// assert_eq!(nix.m(), -1.1);
    /// ```
    ///
    /// Will error for invalid values
    ///
    /// ```rust
    /// # use rv::dist::NormalInvChiSquared;
    /// # let mut nix = NormalInvChiSquared::new(0.0, 1.2, 2.3, 3.4).unwrap();
    /// assert!(nix.set_m(-1.1).is_ok());
    /// assert!(nix.set_m(std::f64::INFINITY).is_err());
    /// assert!(nix.set_m(std::f64::NEG_INFINITY).is_err());
    /// assert!(nix.set_m(std::f64::NAN).is_err());
    /// ```
    #[inline]
    pub fn set_m(&mut self, m: f64) -> Result<(), NormalInvChiSquaredError> {
        if !m.is_finite() {
            return Err(NormalInvChiSquaredError::MNotFinite { m });
        }
        self.set_m_unchecked(m);
        Ok(())
    }

    /// Set the value of m without input validation
    #[inline]
    pub fn set_m_unchecked(&mut self, m: f64) {
        // The cache is built from `v` and `s2` only, so it stays valid.
        self.m = m;
    }

    /// Get the k parameter
    #[inline]
    pub fn k(&self) -> f64 {
        self.k
    }

    /// Set the value of k
    ///
    /// # Example
    ///
    /// ```rust
    /// use rv::dist::NormalInvChiSquared;
    ///
    /// let mut nix = NormalInvChiSquared::new(0.0, 1.2, 2.3, 3.4).unwrap();
    /// assert_eq!(nix.k(), 1.2);
    ///
    /// nix.set_k(4.3).unwrap();
    /// assert_eq!(nix.k(), 4.3);
    /// ```
    ///
    /// Will error for invalid values
    ///
    /// ```rust
    /// # use rv::dist::NormalInvChiSquared;
    /// # let mut nix = NormalInvChiSquared::new(0.0, 1.2, 2.3, 3.4).unwrap();
    /// assert!(nix.set_k(2.1).is_ok());
    ///
    /// // must be greater than zero
    /// assert!(nix.set_k(0.0).is_err());
    /// assert!(nix.set_k(-1.0).is_err());
    ///
    ///
    /// assert!(nix.set_k(std::f64::INFINITY).is_err());
    /// assert!(nix.set_k(std::f64::NEG_INFINITY).is_err());
    /// assert!(nix.set_k(std::f64::NAN).is_err());
    /// ```
    #[inline]
    pub fn set_k(&mut self, k: f64) -> Result<(), NormalInvChiSquaredError> {
        if !k.is_finite() {
            return Err(NormalInvChiSquaredError::KNotFinite { k });
        }
        if k <= 0.0 {
            return Err(NormalInvChiSquaredError::KTooLow { k });
        }
        self.set_k_unchecked(k);
        Ok(())
    }

    /// Set the value of k without input validation
    #[inline]
    pub fn set_k_unchecked(&mut self, k: f64) {
        // The cache is built from `v` and `s2` only, so it stays valid.
        self.k = k;
    }

    /// Get the v parameter
    #[inline]
    pub fn v(&self) -> f64 {
        self.v
    }

    /// Set the value of v
    ///
    /// # Example
    ///
    /// ```rust
    /// use rv::dist::NormalInvChiSquared;
    ///
    /// let mut nix = NormalInvChiSquared::new(0.0, 1.2, 2.3, 3.4).unwrap();
    /// assert_eq!(nix.v(), 2.3);
    ///
    /// nix.set_v(4.3).unwrap();
    /// assert_eq!(nix.v(), 4.3);
    /// ```
    ///
    /// Will error for invalid values
    ///
    /// ```rust
    /// # use rv::dist::NormalInvChiSquared;
    /// # let mut nix = NormalInvChiSquared::new(0.0, 1.2, 2.3, 3.4).unwrap();
    /// assert!(nix.set_v(2.1).is_ok());
    ///
    /// // must be greater than zero
    /// assert!(nix.set_v(0.0).is_err());
    /// assert!(nix.set_v(-1.0).is_err());
    ///
    ///
    /// assert!(nix.set_v(std::f64::INFINITY).is_err());
    /// assert!(nix.set_v(std::f64::NEG_INFINITY).is_err());
    /// assert!(nix.set_v(std::f64::NAN).is_err());
    /// ```
    #[inline]
    pub fn set_v(&mut self, v: f64) -> Result<(), NormalInvChiSquaredError> {
        if !v.is_finite() {
            return Err(NormalInvChiSquaredError::VNotFinite { v });
        }
        if v <= 0.0 {
            return Err(NormalInvChiSquaredError::VTooLow { v });
        }
        // The unchecked setter also clears the cache.
        self.set_v_unchecked(v);
        Ok(())
    }

    /// Set the value of v without input validation
    ///
    /// Also clears the cached `ScaledInvChiSquared`, which depends on `v`.
    #[inline]
    pub fn set_v_unchecked(&mut self, v: f64) {
        self.v = v;
        self.scaled_inv_x2 = OnceCell::new();
    }

    /// Get the s2 parameter
    #[inline]
    pub fn s2(&self) -> f64 {
        self.s2
    }

    /// Set the value of s2
    ///
    /// # Example
    ///
    /// ```rust
    /// use rv::dist::NormalInvChiSquared;
    ///
    /// let mut nix = NormalInvChiSquared::new(0.0, 1.2, 2.3, 3.4).unwrap();
    /// assert_eq!(nix.s2(), 3.4);
    ///
    /// nix.set_s2(4.3).unwrap();
    /// assert_eq!(nix.s2(), 4.3);
    /// ```
    ///
    /// Will error for invalid values
    ///
    /// ```rust
    /// # use rv::dist::NormalInvChiSquared;
    /// # let mut nix = NormalInvChiSquared::new(0.0, 1.2, 2.3, 3.4).unwrap();
    /// assert!(nix.set_s2(2.1).is_ok());
    ///
    /// // must be greater than zero
    /// assert!(nix.set_s2(0.0).is_err());
    /// assert!(nix.set_s2(-1.0).is_err());
    ///
    ///
    /// assert!(nix.set_s2(std::f64::INFINITY).is_err());
    /// assert!(nix.set_s2(std::f64::NEG_INFINITY).is_err());
    /// assert!(nix.set_s2(std::f64::NAN).is_err());
    /// ```
    #[inline]
    pub fn set_s2(&mut self, s2: f64) -> Result<(), NormalInvChiSquaredError> {
        if !s2.is_finite() {
            return Err(NormalInvChiSquaredError::S2NotFinite { s2 });
        }
        if s2 <= 0.0 {
            return Err(NormalInvChiSquaredError::S2TooLow { s2 });
        }
        // The unchecked setter also clears the cache.
        self.set_s2_unchecked(s2);
        Ok(())
    }

    /// Set the value of s2 without input validation
    ///
    /// Also clears the cached `ScaledInvChiSquared`, which depends on `s2`.
    #[inline]
    pub fn set_s2_unchecked(&mut self, s2: f64) {
        self.s2 = s2;
        self.scaled_inv_x2 = OnceCell::new();
    }

    /// Returns the `ScaledInvChiSquared(v, s2)` distribution placed on `σ²`
    ///
    /// The distribution is built from the current `v` and `s2` on first
    /// access and cached until either parameter is set again. It is built
    /// without validation, so it is only as valid as the stored parameters.
    #[inline]
    pub fn scaled_inv_x2(&self) -> &ScaledInvChiSquared {
        self.scaled_inv_x2
            .get_or_init(|| ScaledInvChiSquared::new_unchecked(self.v, self.s2))
    }
}

impl From<&NormalInvChiSquared> for String {
    fn from(nix: &NormalInvChiSquared) -> String {
        format!(
            "Normal-Inverse-X²(m: {}, k: {}, v: {}, s2: {})",
            nix.m, nix.k, nix.v, nix.s2
        )
    }
}

// NOTE(review): `impl_display!` is a crate macro defined outside this file;
// presumably it implements `Display` via the `String` conversion above —
// confirm against the macro definition.
impl_display!(NormalInvChiSquared);

impl Rv<Gaussian> for NormalInvChiSquared {
    /// Log density of the Gaussian `x` under this prior.
    ///
    /// This is the sum of two terms: the log density of `x.sigma()²` under
    /// `ScaledInvChiSquared(v, s2)`, and the log density of `x.mu()` under
    /// `N(m, x.sigma()/√k)`.
    fn ln_f(&self, x: &Gaussian) -> f64 {
        let sigma = x.sigma();
        let lnf_sigma = self.scaled_inv_x2().ln_f(&(sigma * sigma));
        let prior_sigma = sigma / self.k.sqrt();
        let lnf_mu = Gaussian::new_unchecked(self.m, prior_sigma).ln_f(&x.mu());
        lnf_sigma + lnf_mu
    }

    /// Draws a Gaussian by sampling `σ²` from `ScaledInvChiSquared(v, s2)`
    /// and then `μ` from `N(m, σ/√k)`.
    ///
    /// A non-positive variance draw is clamped to a standard deviation of
    /// `f64::EPSILON`, and the same clamped value is used for both the `μ`
    /// draw and the returned Gaussian.
    ///
    /// # Panics
    ///
    /// Panics if `Gaussian::new` rejects the parameters for either the `μ`
    /// distribution or the returned Gaussian.
    fn draw<R: Rng>(&self, mut rng: &mut R) -> Gaussian {
        let var: f64 = self.scaled_inv_x2().draw(&mut rng);

        // Guard against a degenerate variance draw so that the standard
        // deviation is always strictly positive.
        let sigma = if var <= 0.0 {
            std::f64::EPSILON
        } else {
            var.sqrt()
        };

        let mu_sigma: f64 = sigma / self.k.sqrt();
        let mu: f64 = Gaussian::new(self.m, mu_sigma)
            .unwrap_or_else(|err| {
                panic!("Invalid μ params when drawing Gaussian: {}", err)
            })
            .draw(&mut rng);

        // Use the clamped `sigma`, not `var.sqrt()`; otherwise the guard
        // above is bypassed for the returned distribution.
        Gaussian::new(mu, sigma).expect("Invalid params")
    }
}

// Empty impl: only the trait's default methods are used, so no underlying
// source error is reported.
impl std::error::Error for NormalInvChiSquaredError {}

impl std::fmt::Display for NormalInvChiSquaredError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MNotFinite { m } => write!(f, "non-finite m: {}", m),
            Self::KNotFinite { k } => write!(f, "non-finite k: {}", k),
            Self::VNotFinite { v } => write!(f, "non-finite v: {}", v),
            Self::S2NotFinite { s2 } => write!(f, "non-finite s2: {}", s2),
            Self::KTooLow { k } => {
                write!(f, "k ({}) must be greater than zero", k)
            }
            Self::VTooLow { v } => {
                write!(f, "v ({}) must be greater than zero", v)
            }
            Self::S2TooLow { s2 } => {
                write!(f, "s2 ({}) must be greater than zero", s2)
            }
        }
    }
}

// Unit tests, built entirely from crate-level test macros that are defined
// outside this file.
#[cfg(test)]
mod test {
    use super::*;
    use crate::{test_basic_impls, verify_cache_resets};

    // NOTE(review): presumably generates the crate's standard smoke tests for
    // a distribution and one value from its support — confirm against the
    // macro definition.
    test_basic_impls!(
        NormalInvChiSquared::new(0.1, 1.2, 2.3, 3.4).unwrap(),
        Gaussian::new(-1.2, 0.4).unwrap()
    );

    // The four invocations below cover every setter that clears the cached
    // `ScaledInvChiSquared`: `v` and `s2`, each in checked and unchecked form.
    // NOTE(review): judging by the test names, each one checks that `ln_f` is
    // consistent after the setter resets the cache; the last two arguments
    // appear to be the parameter's current value and a different value —
    // confirm against the macro definition.

    // `v`, unchecked setter
    verify_cache_resets!(
        [unchecked],
        ln_f_is_same_after_reset_unchecked_v_identically,
        set_v_unchecked,
        NormalInvChiSquared::new(0.1, 1.2, 2.3, 3.4).unwrap(),
        Gaussian::new(-1.2, 0.4).unwrap(),
        2.3,
        3.14
    );

    // `v`, checked setter
    verify_cache_resets!(
        [checked],
        ln_f_is_same_after_reset_checked_v_identically,
        set_v,
        NormalInvChiSquared::new(0.1, 1.2, 2.3, 3.4).unwrap(),
        Gaussian::new(-1.2, 0.4).unwrap(),
        2.3,
        3.14
    );

    // `s2`, unchecked setter
    verify_cache_resets!(
        [unchecked],
        ln_f_is_same_after_reset_unchecked_s2_identically,
        set_s2_unchecked,
        NormalInvChiSquared::new(0.1, 1.2, 2.3, 3.4).unwrap(),
        Gaussian::new(-1.2, 0.4).unwrap(),
        3.4,
        0.8
    );

    // `s2`, checked setter
    verify_cache_resets!(
        [checked],
        ln_f_is_same_after_reset_checked_s2_identically,
        set_s2,
        NormalInvChiSquared::new(0.1, 1.2, 2.3, 3.4).unwrap(),
        Gaussian::new(-1.2, 0.4).unwrap(),
        3.4,
        0.8
    );
}