1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
//! Distribution over multiple data types
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

use crate::data::Datum;
use crate::dist::Distribution;
use crate::traits::Rv;

/// The joint distribution of several independent component distributions.
///
/// # Notes
///
/// `ProductDistribution` is a thin wrapper around `Vec<Distribution>` that
/// implements `Rv<Vec<Datum>>` and, by wrapping/unwrapping
/// `Datum::Compound`, `Rv<Datum>`. The log density of a datum is the sum of
/// the component log densities, and a draw contains one value from each
/// component, in component order.
///
/// # Example
///
/// Build a two-component mixture whose components are each a
/// Categorical x Gaussian product distribution.
///
/// ```
/// use rv::data::Datum;
/// use rv::dist::{
///     Categorical, Gaussian, Mixture, ProductDistribution, Distribution
/// };
/// use rv::traits::Rv;
///
/// // NOTE: Because the ProductDistribution is an abstraction around Vec<Dist>,
/// // the user must take care to get the order of distributions in each
/// // ProductDistribution correct.
/// let prod_1 = ProductDistribution::new(vec![
///     Distribution::Categorical(Categorical::new(&[0.1, 0.9]).unwrap()),
///     Distribution::Gaussian(Gaussian::new(3.0, 1.0).unwrap()),
/// ]);
///
/// let prod_2 = ProductDistribution::new(vec![
///     Distribution::Categorical(Categorical::new(&[0.9, 0.1]).unwrap()),
///     Distribution::Gaussian(Gaussian::new(-3.0, 1.0).unwrap()),
/// ]);
///
/// let prodmix = Mixture::new(vec![0.5, 0.5], vec![prod_1, prod_2]).unwrap();
///
/// let mut rng = rand::thread_rng();
///
/// let x: Datum = prodmix.draw(&mut rng);
/// let fx = prodmix.f(&x);
///
/// println!("draw: {:?}", x);
/// println!("f(x): {}", fx);
/// ```
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ProductDistribution {
    /// Component distributions; entry `i` of a datum belongs to `dists[i]`.
    dists: Vec<Distribution>,
}

impl ProductDistribution {
    /// Create a new product distribution
    ///
    /// # Example
    ///
    /// ```
    /// use rv::data::Datum;
    /// use rv::dist::{
    ///     Categorical, Gaussian, Mixture, ProductDistribution, Distribution
    /// };
    /// use rv::traits::Rv;
    ///
    /// let prod = ProductDistribution::new(vec![
    ///     Distribution::Categorical(Categorical::new(&[0.1, 0.9]).unwrap()),
    ///     Distribution::Gaussian(Gaussian::new(3.0, 1.0).unwrap()),
    /// ]);
    ///
    /// let mut rng = rand::thread_rng();
    /// let x: Datum = prod.draw(&mut rng);
    /// ```
    pub fn new(dists: Vec<Distribution>) -> Self {
        Self { dists }
    }
}

impl Rv<Vec<Datum>> for ProductDistribution {
    /// Log density of `x`: the sum of each component's log density evaluated
    /// at the entry of `x` in the same position.
    ///
    /// # Panics
    ///
    /// Panics if `x` does not contain exactly one entry per component
    /// distribution.
    fn ln_f(&self, x: &Vec<Datum>) -> f64 {
        // `zip` stops at the shorter input, so without this check a datum of
        // the wrong length would silently be scored against only a subset of
        // the components (a marginal), instead of being reported as misuse.
        assert_eq!(
            self.dists.len(),
            x.len(),
            "product distribution has {} components but datum has {} entries",
            self.dists.len(),
            x.len()
        );
        self.dists
            .iter()
            .zip(x.iter())
            .map(|(dist, x_i)| dist.ln_f(x_i))
            .sum()
    }

    /// Draws one value from each component distribution, in component order.
    fn draw<R: rand::Rng>(&self, rng: &mut R) -> Vec<Datum> {
        self.dists.iter().map(|dist| dist.draw(rng)).collect()
    }
}

impl Rv<Datum> for ProductDistribution {
    /// Scores a `Datum::Compound` by forwarding its entries to the
    /// `Vec<Datum>` implementation.
    ///
    /// # Panics
    ///
    /// Panics if `x` is any `Datum` variant other than `Compound`.
    fn ln_f(&self, x: &Datum) -> f64 {
        if let Datum::Compound(entries) = x {
            <Self as Rv<Vec<Datum>>>::ln_f(self, entries)
        } else {
            panic!("unsupported data type for product distribution")
        }
    }

    /// Draws a `Vec<Datum>` from the components and packages it as a
    /// `Datum::Compound`.
    fn draw<R: rand::Rng>(&self, rng: &mut R) -> Datum {
        let entries = <Self as Rv<Vec<Datum>>>::draw(self, rng);
        Datum::Compound(entries)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::Datum;
    use crate::dist::{Categorical, Distribution, Gaussian};

    /// Categorical([0.1, 0.9]) x standard Gaussian product used by the tests.
    /// (This is a product of two components, not a mixture.)
    fn catgauss_prod() -> ProductDistribution {
        ProductDistribution::new(vec![
            Distribution::Categorical(Categorical::new(&[0.1, 0.9]).unwrap()),
            Distribution::Gaussian(Gaussian::standard()),
        ])
    }

    #[test]
    fn ln_f() {
        let gauss = Gaussian::standard();
        let cat = Categorical::new(&[0.1, 0.9]).unwrap();

        let x_cat = 0_u8;
        let x_gauss = 1.2_f64;

        let x_prod =
            Datum::Compound(vec![Datum::U8(x_cat), Datum::F64(x_gauss)]);

        let ln_f = cat.ln_f(&x_cat) + gauss.ln_f(&x_gauss);
        let ln_f_prod = catgauss_prod().ln_f(&x_prod);

        assert::close(ln_f, ln_f_prod, 1e-12);
    }

    #[test]
    fn ln_f_vec_matches_compound_datum() {
        // The `Vec<Datum>` and `Datum::Compound` entry points must agree.
        let prod = catgauss_prod();
        let xs = vec![Datum::U8(0), Datum::F64(1.2)];

        let ln_f_vec = prod.ln_f(&xs);
        let ln_f_compound = prod.ln_f(&Datum::Compound(xs));

        assert::close(ln_f_vec, ln_f_compound, 1e-12);
    }

    #[test]
    fn draw_yields_one_entry_per_component() {
        let prod = catgauss_prod();
        let mut rng = rand::thread_rng();

        let x: Datum = prod.draw(&mut rng);
        match &x {
            Datum::Compound(xs) => assert_eq!(xs.len(), 2),
            other => panic!("expected Datum::Compound, got {:?}", other),
        }

        // A draw must lie in the support of the distribution it came from.
        assert!(prod.ln_f(&x).is_finite());
    }

    #[test]
    #[should_panic(expected = "unsupported data type for product distribution")]
    fn ln_f_rejects_non_compound_datum() {
        let _ = catgauss_prod().ln_f(&Datum::F64(1.2));
    }
}