1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use crate::{Distribution, *};
use rand::distributions::WeightedIndex;
use rand_distr::Distribution as RandDistribution;
use rayon::prelude::*;
use std::{collections::HashSet, hash::Hash, marker::PhantomData};

/// Posterior over a discrete variable `B` given an observation `A`, assembled from a
/// likelihood, a prior, and an explicit set of candidate values.
///
/// Type parameters:
/// - `L`: the likelihood, a distribution over `A` conditioned on `B`.
/// - `P`: the prior, a distribution over `B` with the unit condition `()`.
/// - `A`: the type of the observed value; it becomes this posterior's `Condition`.
/// - `B`: the type of the candidate values; it becomes this posterior's `Value`.
///   It must be `Eq + Hash` because candidates are stored in a `HashSet`.
#[derive(Clone, Debug)]
pub struct DiscretePosterior<L, P, A, B>
where
    L: Distribution<Value = A, Condition = B>,
    P: Distribution<Value = B, Condition = ()>,
    A: RandomVariable,
    B: RandomVariable + Eq + Hash,
{
    /// Kernel of an observation `A` under a candidate `B`.
    likelihood: L,
    /// Kernel of a candidate `B` on its own (condition is `()`).
    prior: P,
    /// Candidate values of `B` that are weighted and sampled from.
    range: HashSet<B>,
    /// Ties the type parameter `A` to the struct, since no field stores an `A`.
    phantom: PhantomData<A>,
}

impl<L, P, A, B> DiscretePosterior<L, P, A, B>
where
    L: Distribution<Value = A, Condition = B>,
    P: Distribution<Value = B, Condition = ()>,
    A: RandomVariable,
    B: RandomVariable + Eq + Hash,
{
    /// Creates a posterior from a `likelihood`, a `prior`, and the set of candidate
    /// values (`range`) the posterior is evaluated over.
    pub fn new(likelihood: L, prior: P, range: HashSet<B>) -> Self {
        Self {
            likelihood,
            prior,
            range,
            phantom: PhantomData,
        }
    }

    /// Pairs every candidate in `range` with its unnormalized weight
    /// `likelihood.p_kernel(theta, u) * prior.p_kernel(u, &())`.
    ///
    /// The candidates are visited through rayon's parallel iterator.
    ///
    /// # Errors
    ///
    /// Returns the `DistributionError` produced by a failing likelihood or prior
    /// kernel evaluation.
    fn weighted(&self, theta: &A) -> Result<Vec<(f64, &B)>, DistributionError> {
        self.range
            .par_iter()
            .map(|u| -> Result<_, DistributionError> {
                let weight = self.likelihood.p_kernel(theta, u)? * self.prior.p_kernel(u, &())?;
                Ok((weight, u))
            })
            .collect::<Result<Vec<(f64, &B)>, DistributionError>>()
    }

    /// Builds a `WeightedIndex` over the weights stored in `weighted`.
    ///
    /// When `WeightedIndex::new` rejects the computed weights, every candidate is
    /// given the same weight `1.0` instead, so a draw is still possible.
    ///
    /// The current implementation never returns `Err`; the `Result` return type is
    /// kept so callers can continue to use `?`.
    ///
    /// # Panics
    ///
    /// Panics if the equal-weight fallback cannot be built either. Per the `rand`
    /// documentation this happens when there are no weights at all, i.e. when
    /// `weighted` is empty.
    fn index(&self, weighted: &[(f64, &B)]) -> Result<WeightedIndex<f64>, DistributionError> {
        let index = WeightedIndex::new(weighted.iter().map(|(w, _)| *w)).unwrap_or_else(|_| {
            // Best-effort fallback: sample uniformly among the candidates.
            WeightedIndex::new(vec![1.0; weighted.len()])
                .expect("uniform fallback needs at least one candidate in `range`")
        });
        Ok(index)
    }
}

impl<L, P, A, B> Distribution for DiscretePosterior<L, P, A, B>
where
    L: Distribution<Value = A, Condition = B>,
    P: Distribution<Value = B, Condition = ()>,
    A: RandomVariable,
    B: RandomVariable + Eq + Hash,
{
    type Value = B;
    type Condition = A;

    /// Evaluates the posterior kernel of candidate `x` given the observation `theta`.
    ///
    /// The result is the product of the likelihood kernel of `theta` under `x` and
    /// the prior kernel of `x`; no normalization is applied here.
    ///
    /// # Errors
    ///
    /// Propagates the error of the likelihood evaluation, or, if that succeeds,
    /// of the prior evaluation.
    fn p_kernel(&self, x: &Self::Value, theta: &Self::Condition) -> Result<f64, DistributionError> {
        // The likelihood is evaluated first so that its error takes precedence.
        let likelihood_term = self.likelihood.p_kernel(theta, x)?;
        let prior_term = self.prior.p_kernel(x, &())?;
        Ok(likelihood_term * prior_term)
    }
}

impl<L, P, A, B> SampleableDistribution for DiscretePosterior<L, P, A, B>
where
    L: Distribution<Value = A, Condition = B>,
    P: Distribution<Value = B, Condition = ()>,
    A: RandomVariable,
    B: RandomVariable + Eq + Hash,
{
    /// Draws one candidate from `range` given the observation `theta`.
    ///
    /// Every candidate is weighted by `weighted`, a `WeightedIndex` is built from
    /// those weights by `index`, and the candidate at the sampled position is
    /// cloned and returned.
    ///
    /// # Errors
    ///
    /// Propagates any `DistributionError` from `weighted` or `index`.
    fn sample(
        &self,
        theta: &Self::Condition,
        rng: &mut dyn rand::RngCore,
    ) -> Result<Self::Value, DistributionError> {
        let candidates = self.weighted(theta)?;
        let sampler = self.index(&candidates)?;

        // `slot` is a position into `candidates`, produced by the weighted sampler.
        let slot = sampler.sample(rng);
        let chosen: &B = candidates[slot].1;

        Ok(chosen.clone())
    }
}

#[cfg(test)]
mod tests {
    use crate::distribution::Distribution;
    use crate::*;
    use std::collections::HashSet;

    /// With a prior parameter of 0.5 and an observation of 1.0, the candidate
    /// `false` (likelihood parameterized with 0.0) must receive a larger kernel
    /// value than `true` (likelihood parameterized with 10.0).
    #[test]
    fn it_works() {
        let candidates: HashSet<bool> = vec![true, false].into_iter().collect();

        let likelihood =
            Normal.condition(|x: &bool| NormalParams::new(if *x { 10.0 } else { 0.0 }, 1.0));
        let prior = Bernoulli.condition(|_x: &()| BernoulliParams::new(0.5));
        let model = DiscretePosterior::new(likelihood, prior, candidates);

        let observed = 1.0;
        let true_result = model.p_kernel(&true, &observed).unwrap();
        let false_result = model.p_kernel(&false, &observed).unwrap();

        assert!(true_result < false_result);
    }
}