1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
use crate::{
    ConditionDifferentiableDistribution, ContinuousSamplesDistribution, Distribution,
    DistributionError, RandomVariable, ValueDifferentiableDistribution,
};
use opensrdk_kernel_method::*;
use opensrdk_linear_algebra::*;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

/// Adjust samples {b} from posterior p(b|a) with likelihood p(a|b) and prior p(b)
///
/// Every component is borrowed for `'a`. `L` is the likelihood, whose value type
/// is `A` and whose condition type is `B`; `P` is the prior over `B`; `K` is a
/// kernel over `Vec<f64>` that can be differentiated with respect to its value.
pub struct SteinVariational<'a, L, P, A, B, K>
where
    L: Distribution<Value = A, Condition = B> + ConditionDifferentiableDistribution,
    P: Distribution<Value = B, Condition = ()> + ValueDifferentiableDistribution,
    A: RandomVariable,
    B: RandomVariable,
    K: PositiveDefiniteKernel<Vec<f64>> + ValueDifferentiableKernel<Vec<f64>>,
{
    /// Value `a` handed to the likelihood when its condition gradient is evaluated.
    value: &'a A,
    /// Likelihood p(a|b); differentiated with respect to its condition `b`.
    likelihood: &'a L,
    /// Prior p(b); differentiated with respect to its value `b`.
    prior: &'a P,
    /// Kernel evaluated between a point and each stored sample.
    kernel: &'a K,
    /// Parameters passed unchanged to every kernel evaluation.
    kernel_params: &'a [f64],
    /// Samples {b} in `Vec<f64>` form. Borrowed mutably, although the methods
    /// visible in this file only read them.
    samples: &'a mut ContinuousSamplesDistribution<Vec<f64>>,
}

impl<'a, L, P, A, B, K> SteinVariational<'a, L, P, A, B, K>
where
    L: Distribution<Value = A, Condition = B> + ConditionDifferentiableDistribution,
    P: Distribution<Value = B, Condition = ()> + ValueDifferentiableDistribution,
    A: RandomVariable,
    B: RandomVariable,
    K: PositiveDefiniteKernel<Vec<f64>> + ValueDifferentiableKernel<Vec<f64>>,
{
    /// Bundles the borrowed model components into a `SteinVariational`.
    ///
    /// Nothing is computed or validated here; every argument is stored as given.
    pub fn new(
        value: &'a A,
        likelihood: &'a L,
        prior: &'a P,
        kernel: &'a K,
        kernel_params: &'a [f64],
        samples: &'a mut ContinuousSamplesDistribution<Vec<f64>>,
    ) -> Self {
        Self {
            value,
            likelihood,
            prior,
            kernel,
            kernel_params,
            samples,
        }
    }

    /// Computes the update direction for `theta`, averaged over the stored samples.
    ///
    /// For every stored sample `theta_j` the term `k * g + d` is formed, where
    /// `k` is `kernel.value(kernel_params, theta, theta_j)`, `d` is
    /// `kernel.ln_diff_value(kernel_params, theta, theta_j)` and `g` is the sum of
    /// `likelihood.ln_diff_condition(value, theta)` and
    /// `prior.ln_diff_value(theta, ())`. The result is the mean of these terms.
    ///
    /// When there are no samples, a zero vector with the length of `theta`'s
    /// vector form is returned and the model is not evaluated.
    ///
    /// NOTE(review): `g` is evaluated at `theta`, not at each `theta_j`; confirm
    /// that this is the intended form of the update.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the likelihood or the prior when either
    /// gradient cannot be evaluated.
    ///
    /// # Panics
    ///
    /// Panics if the kernel fails to evaluate its value or its gradient; no
    /// conversion from the kernel's error type is visible in this file.
    pub fn direction(&self, theta: &B) -> Result<Vec<f64>, DistributionError> {
        let samples = self.samples.samples();
        let n = samples.len();
        let theta_vec = theta.clone().transform_vec().0;

        if n == 0 {
            return Ok(vec![0.0; theta_vec.len()]);
        }

        // The likelihood/prior gradient does not depend on `theta_j`, so it is
        // evaluated once here instead of once per sample. Keeping it as a plain
        // `Vec<f64>` lets every parallel task build its own column matrix.
        let score = (self
            .likelihood
            .ln_diff_condition(self.value, theta)?
            .col_mat()
            + self.prior.ln_diff_value(theta, &())?.col_mat())
        .vec();

        let sample_count = n as f64;

        let phi = (0..n)
            .into_par_iter()
            .map(|j| {
                let theta_j = &samples[j];
                let kernel = self
                    .kernel
                    .value(self.kernel_params, &theta_vec, theta_j)
                    .unwrap();
                let kernel_diff = self
                    .kernel
                    .ln_diff_value(self.kernel_params, &theta_vec, theta_j)
                    .unwrap()
                    .col_mat();

                // Scale each term here so that the reduction below is a plain sum.
                (kernel * score.clone().col_mat() + kernel_diff) / sample_count
            })
            // rayon may combine partial sums in any grouping, so the operator must
            // be associative. Dividing inside the operator would rescale partial
            // sums again each time two of them are merged.
            .reduce(
                || vec![0.0; theta_vec.len()].col_mat(),
                |lhs, rhs| lhs + rhs,
            );

        Ok(phi.vec())
    }
}