use crate::data::extract_stat_then;
use crate::data::{DataOrSuffStat, ShiftedSuffStat};
use crate::dist::Shifted;
use crate::traits::{
ConjugatePrior, HasDensity, HasSuffStat, Parameterized, Sampleable,
Shiftable,
};
use rand::Rng;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use std::fmt;
use std::marker::PhantomData;
/// A prior over [`Shifted`] distributions.
///
/// Wraps a prior `parent` over the unshifted distribution type `Fx` together
/// with a fixed `shift`. Sampling draws an `Fx` from `parent` and wraps it in a
/// `Shifted` carrying this prior's `shift`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct ShiftedPrior<Pr, Fx>
where
    Pr: Sampleable<Fx>,
    Fx: Shiftable,
{
    /// Prior over the unshifted distribution `Fx`.
    parent: Pr,
    /// Shift attached to every distribution produced by this prior.
    /// `new` rejects non-finite values; `new_unchecked` does not check.
    shift: f64,
    /// Records the `Fx` type parameter; no `Fx` value is stored.
    _phantom: PhantomData<Fx>,
}
/// Errors that can occur when constructing a `ShiftedPrior`.
#[derive(Debug, Clone, PartialEq)]
pub enum ShiftedPriorError {
    /// The requested shift was NaN or infinite; carries the offending value.
    NonFiniteShift(f64),
}

impl std::error::Error for ShiftedPriorError {}

#[cfg_attr(coverage_nightly, coverage(off))]
impl fmt::Display for ShiftedPriorError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // There is only one variant, so this destructuring is irrefutable; adding
        // a variant turns it into a compile error, just like an exhaustive match.
        let Self::NonFiniteShift(bad_shift) = self;
        write!(f, "non-finite shift: {bad_shift}")
    }
}
impl<Pr, Fx> ShiftedPrior<Pr, Fx>
where
    Pr: Sampleable<Fx>,
    Fx: Shiftable,
{
    /// Builds a shifted prior from `parent` and `shift`.
    ///
    /// # Errors
    /// Returns [`ShiftedPriorError::NonFiniteShift`] when `shift` is NaN or
    /// infinite.
    pub fn new(parent: Pr, shift: f64) -> Result<Self, ShiftedPriorError> {
        if !shift.is_finite() {
            return Err(ShiftedPriorError::NonFiniteShift(shift));
        }
        Ok(Self::new_unchecked(parent, shift))
    }

    /// Builds a shifted prior without validating `shift`.
    ///
    /// The caller is responsible for passing a finite shift.
    pub fn new_unchecked(parent: Pr, shift: f64) -> Self {
        Self {
            parent,
            shift,
            _phantom: PhantomData,
        }
    }

    /// Borrows the wrapped prior over the unshifted distribution.
    pub fn parent(&self) -> &Pr {
        &self.parent
    }

    /// Mutably borrows the wrapped prior over the unshifted distribution.
    pub fn parent_mut(&mut self) -> &mut Pr {
        &mut self.parent
    }

    /// The shift applied to distributions produced by this prior.
    pub fn shift(&self) -> f64 {
        self.shift
    }
}
/// Sampling draws an unshifted `Fx` from the parent prior and wraps it in a
/// `Shifted` that carries this prior's shift.
impl<Pr, Fx> Sampleable<Shifted<Fx>> for ShiftedPrior<Pr, Fx>
where
    Pr: Sampleable<Fx>,
    Fx: Shiftable,
{
    fn draw<R: Rng>(&self, rng: &mut R) -> Shifted<Fx> {
        // The shift was either validated by `new` or vouched for by the caller
        // of `new_unchecked`, so the unchecked constructor is used here.
        Shifted::new_unchecked(self.parent.draw(rng), self.shift)
    }
}
/// Parameters of a `ShiftedPrior`, as produced by `emit_params` and consumed
/// by `from_params` in its `Parameterized` implementation.
///
/// NOTE(review): both fields are private and no constructor or accessor is
/// visible here, so code outside this module can apparently only obtain a
/// value via `emit_params`. Confirm whether that is intended.
pub struct ShiftedPriorParameters<Pr: Parameterized> {
    /// Parameters of the parent (unshifted) prior.
    parent: Pr::Parameters,
    /// The prior's shift.
    shift: f64,
}
/// Round-trips a `ShiftedPrior` through its parent's parameters plus the shift.
impl<Pr, Fx> Parameterized for ShiftedPrior<Pr, Fx>
where
    Pr: Sampleable<Fx> + Parameterized,
    Fx: Shiftable,
{
    type Parameters = ShiftedPriorParameters<Pr>;

    fn emit_params(&self) -> Self::Parameters {
        let parent_params = self.parent.emit_params();
        ShiftedPriorParameters {
            parent: parent_params,
            shift: self.shift,
        }
    }

    /// Rebuilds the prior from emitted parameters.
    ///
    /// The shift is not re-validated, because this method cannot report an error.
    fn from_params(params: Self::Parameters) -> Self {
        let ShiftedPriorParameters { parent, shift } = params;
        let rebuilt_parent = Pr::from_params(parent);
        Self::new_unchecked(rebuilt_parent, shift)
    }
}
/// Conjugate updating for `ShiftedPrior`.
///
/// All inference is delegated to the parent prior. Raw observations `x` of a
/// `Shifted<Fx>` are mapped to `x - shift` before reaching the parent. Sufficient
/// statistics are forwarded through their wrapped parent statistic. Posteriors
/// keep the same shift as `self`.
impl<Pr, Fx> ConjugatePrior<f64, Shifted<Fx>> for ShiftedPrior<Pr, Fx>
where
    Pr: ConjugatePrior<f64, Fx, Posterior = Pr>,
    Fx: HasSuffStat<f64> + Shiftable + HasDensity<f64>,
    Shifted<Fx>: HasSuffStat<f64, Stat = ShiftedSuffStat<Fx::Stat>>,
{
    type Posterior = Self;
    type MCache = Pr::MCache;
    type PpCache = Pr::PpCache;

    /// Returns the parent's empty statistic, wrapped together with this
    /// prior's shift.
    fn empty_stat(&self) -> ShiftedSuffStat<Fx::Stat> {
        let parent_stat = self.parent.empty_stat();
        ShiftedSuffStat::new(parent_stat, self.shift)
    }

    /// Updates the parent prior with the wrapped parent statistic and
    /// re-applies this prior's shift to the result.
    fn posterior_from_suffstat(
        &self,
        stat: &ShiftedSuffStat<Fx::Stat>,
    ) -> Self::Posterior {
        ShiftedPrior::new_unchecked(
            self.parent.posterior_from_suffstat(stat.parent()),
            self.shift,
        )
    }

    fn posterior(
        &self,
        x: &DataOrSuffStat<f64, Shifted<Fx>>,
    ) -> Self::Posterior {
        extract_stat_then(self, x, |stat| self.posterior_from_suffstat(stat))
    }

    fn ln_m_cache(&self) -> Self::MCache {
        self.parent.ln_m_cache()
    }

    /// Log marginal likelihood of `x` under this prior.
    ///
    /// Raw data are un-shifted (`x - shift`) before being passed to the parent.
    /// A sufficient statistic is forwarded through its wrapped parent statistic,
    /// mirroring `posterior_from_suffstat`, so both input forms describe the
    /// same observations.
    fn ln_m_with_cache(
        &self,
        cache: &Self::MCache,
        x: &DataOrSuffStat<f64, Shifted<Fx>>,
    ) -> f64 {
        match x {
            DataOrSuffStat::Data(xs) => {
                let unshifted: Vec<f64> =
                    xs.iter().map(|&xi| xi - self.shift).collect();
                self.parent
                    .ln_m_with_cache(cache, &DataOrSuffStat::Data(&unshifted))
            }
            // The statistic must be forwarded, not dropped. Evaluating the
            // parent on an empty dataset would ignore every observation.
            DataOrSuffStat::SuffStat(stat) => self.parent.ln_m_with_cache(
                cache,
                &DataOrSuffStat::SuffStat(stat.parent()),
            ),
        }
    }

    /// Builds the parent's posterior-predictive cache from the wrapped
    /// parent statistic of `x`.
    fn ln_pp_cache(
        &self,
        x: &DataOrSuffStat<f64, Shifted<Fx>>,
    ) -> Self::PpCache {
        extract_stat_then(self, x, |stat| {
            self.parent
                .ln_pp_cache(&DataOrSuffStat::SuffStat(stat.parent()))
        })
    }

    /// Log posterior predictive of `y`, evaluated by the parent at `y - shift`.
    fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 {
        let shifted_y = *y - self.shift;
        self.parent.ln_pp_with_cache(cache, &shifted_y)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::DataOrSuffStat;
    use crate::dist::NormalInvChiSquared;
    use crate::traits::SuffStat;
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256Plus;

    #[test]
    fn test_shifted_prior_draw() {
        let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0);
        let shifted_prior = ShiftedPrior::new(prior, 2.0).unwrap();
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let dist = shifted_prior.draw(&mut rng);
        assert_eq!(dist.shift(), 2.0);
    }

    #[test]
    fn test_shifted_prior_conjugate() {
        let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0);
        let shifted_prior = ShiftedPrior::new(prior, 2.0).unwrap();
        let stat = shifted_prior.empty_stat();
        assert_eq!(stat.shift(), 2.0);
        let data: Vec<f64> = Vec::new();
        let dos = DataOrSuffStat::Data(&data);
        let posterior = shifted_prior.posterior(&dos);
        assert_eq!(posterior.shift(), 2.0);
    }

    #[test]
    fn test_shifted_prior_with_data() {
        let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0);
        let shifted_prior = ShiftedPrior::new(prior, 2.0).unwrap();
        let data = vec![2.0, 4.0, 6.0];
        let dos = DataOrSuffStat::Data(&data);
        let posterior = shifted_prior.posterior(&dos);
        assert_eq!(posterior.shift(), 2.0);
        let ln_m = shifted_prior.ln_m(&dos);
        let ln_pp = shifted_prior.ln_pp(&2.0, &dos);
        assert!(ln_m.is_finite());
        assert!(ln_pp.is_finite());
    }

    /// Regression test: `ln_m` computed from a sufficient statistic must match
    /// `ln_m` computed from the raw data it summarizes. It must not fall back
    /// to the prior marginal of an empty dataset.
    #[test]
    fn test_shifted_prior_ln_m_suffstat_matches_data() {
        let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 2.0, 1.0);
        let shifted_prior = ShiftedPrior::new(prior, 2.0).unwrap();
        let data = vec![2.0, 4.0, 6.0];

        let mut stat = shifted_prior.empty_stat();
        for x in &data {
            stat.observe(x);
        }

        let ln_m_data = shifted_prior.ln_m(&DataOrSuffStat::Data(&data));
        let ln_m_stat = shifted_prior.ln_m(&DataOrSuffStat::SuffStat(&stat));
        assert!(ln_m_stat.is_finite());
        assert!(
            (ln_m_data - ln_m_stat).abs() < 1e-10,
            "ln_m from data ({ln_m_data}) != ln_m from suffstat ({ln_m_stat})"
        );
    }
}