//! Distribution over multiple data types
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::data::Datum;
use crate::dist::Distribution;
use crate::traits::Rv;
/// A product distribution: the joint distribution of several independent
/// component distributions.
///
/// # Notes
///
/// `ProductDistribution` is a thin wrapper around `Vec<Distribution>`, which
/// allows `Rv<Vec<Datum>>` (and `Rv<Datum>`, through `Datum::Compound`) to be
/// implemented for it. The i-th entry of an observation is evaluated under
/// the i-th component, and the component log densities are summed.
///
/// # Example
///
/// Create a mixture of product distributions of Categorical * Gaussian
///
/// ```
/// use rv::data::Datum;
/// use rv::dist::{
///     Categorical, Gaussian, Mixture, ProductDistribution, Distribution
/// };
/// use rv::traits::Rv;
///
/// // NOTE: Because the ProductDistribution is an abstraction around
/// // Vec<Distribution>, the user must take care to get the order of
/// // distributions in each ProductDistribution correct.
/// let prod_1 = ProductDistribution::new(vec![
///     Distribution::Categorical(Categorical::new(&[0.1, 0.9]).unwrap()),
///     Distribution::Gaussian(Gaussian::new(3.0, 1.0).unwrap()),
/// ]);
///
/// let prod_2 = ProductDistribution::new(vec![
///     Distribution::Categorical(Categorical::new(&[0.9, 0.1]).unwrap()),
///     Distribution::Gaussian(Gaussian::new(-3.0, 1.0).unwrap()),
/// ]);
///
/// let prodmix = Mixture::new(vec![0.5, 0.5], vec![prod_1, prod_2]).unwrap();
///
/// let mut rng = rand::thread_rng();
///
/// let x: Datum = prodmix.draw(&mut rng);
/// let fx = prodmix.f(&x);
///
/// println!("draw: {:?}", x);
/// println!("f(x): {}", fx);
/// ```
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ProductDistribution {
/// Component distributions, in the order their data appear in an observation.
dists: Vec<Distribution>,
}
impl ProductDistribution {
/// Create a new product distribution
///
/// # Example
///
/// ```
/// use rv::data::Datum;
/// use rv::dist::{
/// Categorical, Gaussian, Mixture, ProductDistribution, Distribution
/// };
/// use rv::traits::Rv;
///
/// let prod = ProductDistribution::new(vec![
/// Distribution::Categorical(Categorical::new(&[0.1, 0.9]).unwrap()),
/// Distribution::Gaussian(Gaussian::new(3.0, 1.0).unwrap()),
/// ]);
///
/// let mut rng = rand::thread_rng();
/// let x: Datum = prod.draw(&mut rng);
/// ```
pub fn new(dists: Vec<Distribution>) -> Self {
Self { dists }
}
}
impl Rv<Vec<Datum>> for ProductDistribution {
    /// Log density of the joint observation `x`.
    ///
    /// The i-th entry of `x` is evaluated under the i-th component
    /// distribution and the resulting log densities are summed.
    ///
    /// # Panics
    ///
    /// Panics if `x` does not contain exactly one entry per component
    /// distribution. Without this check `zip` would stop at the shorter of
    /// the two sequences and silently return the density of a truncated
    /// observation.
    fn ln_f(&self, x: &Vec<Datum>) -> f64 {
        assert_eq!(
            self.dists.len(),
            x.len(),
            "product distribution has {} components but the datum has {} entries",
            self.dists.len(),
            x.len()
        );

        self.dists
            .iter()
            .zip(x.iter())
            .map(|(dist, x_i)| dist.ln_f(x_i))
            .sum()
    }

    /// Draw one value from every component, in component order.
    fn draw<R: rand::Rng>(&self, rng: &mut R) -> Vec<Datum> {
        self.dists.iter().map(|dist| dist.draw(rng)).collect()
    }
}
impl Rv<Datum> for ProductDistribution {
    /// Log density of a `Datum::Compound` observation.
    ///
    /// Delegates to the `Rv<Vec<Datum>>` implementation on the wrapped
    /// vector.
    ///
    /// # Panics
    ///
    /// Panics if `x` is any variant other than `Datum::Compound`.
    fn ln_f(&self, x: &Datum) -> f64 {
        if let Datum::Compound(entries) = x {
            <Self as Rv<Vec<Datum>>>::ln_f(self, entries)
        } else {
            panic!("unsupported data type for product distribution")
        }
    }

    /// Draw from every component and wrap the result in `Datum::Compound`.
    fn draw<R: rand::Rng>(&self, rng: &mut R) -> Datum {
        let entries = <Self as Rv<Vec<Datum>>>::draw(self, rng);
        Datum::Compound(entries)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::Datum;
    use crate::dist::{Categorical, Distribution, Gaussian};

    /// Categorical([0.1, 0.9]) * standard Gaussian fixture shared by the tests.
    fn catgauss_mix() -> ProductDistribution {
        let components = vec![
            Distribution::Categorical(Categorical::new(&[0.1, 0.9]).unwrap()),
            Distribution::Gaussian(Gaussian::standard()),
        ];
        ProductDistribution::new(components)
    }

    /// The product log density must equal the sum of the component log
    /// densities evaluated on the matching entries.
    #[test]
    fn ln_f() {
        let categorical = Categorical::new(&[0.1, 0.9]).unwrap();
        let gaussian = Gaussian::standard();

        let category = 0_u8;
        let value = 1.2_f64;
        let expected = categorical.ln_f(&category) + gaussian.ln_f(&value);

        let observation =
            Datum::Compound(vec![Datum::U8(category), Datum::F64(value)]);
        let actual = catgauss_mix().ln_f(&observation);

        assert::close(expected, actual, 1e-12);
    }
}