1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
use crate::{
ConditionDifferentiableDistribution, ContinuousSamplesDistribution, Distribution,
DistributionError, RandomVariable, ValueDifferentiableDistribution,
};
use opensrdk_kernel_method::*;
use opensrdk_linear_algebra::*;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
/// Borrowed inputs needed to compute Stein variational directions (see `direction`).
///
/// `B` is the type of the particles being moved: it is the likelihood's condition type
/// and the prior's value type. `A` is the type of the observed value fed to the likelihood.
pub struct SteinVariational<'a, L, P, A, B, K>
where
    A: RandomVariable,
    B: RandomVariable,
    K: PositiveDefiniteKernel<Vec<f64>> + ValueDifferentiableKernel<Vec<f64>>,
    L: Distribution<Value = A, Condition = B> + ConditionDifferentiableDistribution,
    P: Distribution<Value = B, Condition = ()> + ValueDifferentiableDistribution,
{
    /// Kernel evaluated between the particle and each stored sample.
    kernel: &'a K,
    /// Parameters forwarded verbatim to every kernel evaluation.
    kernel_params: &'a [f64],
    /// Stored samples (as flat vectors) that the direction is averaged over.
    samples: &'a mut ContinuousSamplesDistribution<Vec<f64>>,
    /// Observed value passed to the likelihood when its gradient is taken.
    value: &'a A,
    /// Likelihood whose gradient w.r.t. its condition contributes to the direction.
    likelihood: &'a L,
    /// Prior over the particle type `B` (unconditioned).
    prior: &'a P,
}
impl<'a, L, P, A, B, K> SteinVariational<'a, L, P, A, B, K>
where
    L: Distribution<Value = A, Condition = B> + ConditionDifferentiableDistribution,
    P: Distribution<Value = B, Condition = ()> + ValueDifferentiableDistribution,
    A: RandomVariable,
    B: RandomVariable,
    K: PositiveDefiniteKernel<Vec<f64>> + ValueDifferentiableKernel<Vec<f64>>,
{
    /// Bundles the observed value, likelihood, prior, kernel (with its parameters) and the
    /// sample set. No computation is performed here.
    pub fn new(
        value: &'a A,
        likelihood: &'a L,
        prior: &'a P,
        kernel: &'a K,
        kernel_params: &'a [f64],
        samples: &'a mut ContinuousSamplesDistribution<Vec<f64>>,
    ) -> Self {
        Self {
            value,
            likelihood,
            prior,
            kernel,
            kernel_params,
            samples,
        }
    }

    /// Computes the Stein variational direction for the particle `theta`.
    ///
    /// With `θ = theta.transform_vec().0`, `θ_j` the `n` stored samples and
    /// `g = likelihood.ln_diff_condition(value, theta) + prior.ln_diff_value(theta, &())`,
    /// the result is
    ///
    /// `(1/n) * Σ_j [ kernel.value(params, θ, θ_j) * g + kernel.ln_diff_value(params, θ, θ_j) ]`.
    ///
    /// When there are no stored samples, a zero vector of length `θ.len()` is returned.
    ///
    /// # Errors
    /// Propagates any error from the likelihood or prior gradient evaluation.
    ///
    /// # Panics
    /// Panics if the kernel value or kernel gradient cannot be evaluated for some sample.
    pub fn direction(&self, theta: &B) -> Result<Vec<f64>, DistributionError> {
        let samples = self.samples.samples();
        let n = samples.len();
        let theta_vec = theta.clone().transform_vec().0;
        let dim = theta_vec.len();

        // Nothing to average over: keep the zero direction without touching the model.
        if n == 0 {
            return Ok(vec![0.0; dim]);
        }

        // The log-density gradient is taken at `theta` only, so it is the same for every
        // sample; compute it once instead of once per sample.
        // NOTE(review): standard SVGD evaluates ∇ln p at each sample θ_j, not at θ —
        // confirm this is intended.
        let p_diff = self.likelihood.ln_diff_condition(self.value, theta)?.col_mat()
            + self.prior.ln_diff_value(theta, &())?.col_mat();

        // Sum kernel values and kernel gradients over all samples in parallel. The
        // combining op is plain addition (associative), so the result does not depend on
        // how rayon splits the work; the 1/n scaling is applied once afterwards.
        let (kernel_sum, kernel_diff_sum) = (0..n)
            .into_par_iter()
            .map(|j| {
                let theta_j = &samples[j];
                let k = self
                    .kernel
                    .value(self.kernel_params, &theta_vec, theta_j)
                    .expect("kernel value evaluation failed");
                let k_diff = self
                    .kernel
                    .ln_diff_value(self.kernel_params, &theta_vec, theta_j)
                    .expect("kernel gradient evaluation failed")
                    .col_mat();
                (k, k_diff)
            })
            .reduce(
                || (0.0, vec![0.0; dim].col_mat()),
                |(k_a, d_a), (k_b, d_b)| (k_a + k_b, d_a + d_b),
            );

        let phi = (kernel_sum * p_diff + kernel_diff_sum) / n as f64;
        Ok(phi.vec())
    }
}