use crate::Distribution;
/// Gaussian (RBF) kernel `exp(-‖x − y‖² / (2·h²))` with bandwidth `h`.
///
/// Coordinates are compared pairwise over the length of `x`; any extra
/// coordinates in `y` are ignored.
///
/// # Panics
/// Panics if `y` has fewer coordinates than `x`.
fn rbf_kernel(x: &[f64], y: &[f64], bandwidth: f64) -> f64 {
    let squared_distance: f64 = x
        .iter()
        .enumerate()
        .map(|(idx, &coord)| {
            let delta = coord - y[idx];
            delta * delta
        })
        .sum();
    let denom = 2.0 * bandwidth * bandwidth;
    (-squared_distance / denom).exp()
}
/// Gradient of [`rbf_kernel`] with respect to its FIRST argument `x`:
/// `∇ₓ k(x, y) = −(x − y) · k(x, y) / h²`.
///
/// The result has one entry per coordinate of `x`.
///
/// # Panics
/// Panics (inside [`rbf_kernel`]) if `y` has fewer coordinates than `x`.
fn rbf_kernel_grad(x: &[f64], y: &[f64], bandwidth: f64) -> Vec<f64> {
    let kernel_value = rbf_kernel(x, y, bandwidth);
    let h_sq = bandwidth * bandwidth;
    let mut gradient = Vec::with_capacity(x.len());
    for (&x_coord, &y_coord) in x.iter().zip(y.iter()) {
        gradient.push(-(x_coord - y_coord) * kernel_value / h_sq);
    }
    gradient
}
/// Median-heuristic kernel bandwidth: the median pairwise Euclidean distance
/// between particles (the upper median when the number of pairs is even).
///
/// Returns `1.0` in these cases:
/// * there are fewer than two particles;
/// * no pairwise distance is finite (non-finite distances, e.g. from NaN
///   coordinates, are ignored rather than causing a panic);
/// * the median distance is zero, e.g. when all particles coincide. A zero
///   bandwidth would make the RBF kernel evaluate `0 / 0` and yield NaN.
fn median_bandwidth(particles: &[Vec<f64>]) -> f64 {
    const FALLBACK: f64 = 1.0;
    let n = particles.len();
    if n <= 1 {
        return FALLBACK;
    }
    let mut distances: Vec<f64> = Vec::with_capacity(n * (n - 1) / 2);
    for (i, p_i) in particles.iter().enumerate() {
        for p_j in &particles[i + 1..] {
            let sq_dist: f64 = p_i
                .iter()
                .zip(p_j)
                .map(|(a, b)| {
                    let diff = a - b;
                    diff * diff
                })
                .sum();
            let dist = sq_dist.sqrt();
            if dist.is_finite() {
                distances.push(dist);
            }
        }
    }
    if distances.is_empty() {
        return FALLBACK;
    }
    // Selection is O(m) and picks the same element a full sort would place at `mid`.
    let mid = distances.len() / 2;
    let median = *distances.select_nth_unstable_by(mid, f64::total_cmp).1;
    if median > 0.0 {
        median
    } else {
        FALLBACK
    }
}
/// Performs one Stein Variational Gradient Descent (SVGD) update in place.
///
/// Each particle `xᵢ` is moved by `step_size · φ(xᵢ)`, where
///
/// `φ(xᵢ) = (1/n) Σⱼ [ k(xⱼ, xᵢ) · ∇log p(xⱼ) + ∇_{xⱼ} k(xⱼ, xᵢ) ]`
///
/// and `k` is the RBF kernel ([`rbf_kernel`]). The first term drives particles
/// toward high density. The second term (the kernel gradient with respect to the
/// *neighbour* `xⱼ`) pushes particles apart, which keeps the ensemble from
/// collapsing onto a single mode.
///
/// * `log_prob_grads[j]` — ∇log p evaluated at `particles.samples[j]`.
/// * `step_size` — multiplier applied to `φ` for the update.
/// * `bandwidth` — kernel bandwidth; `None` uses [`median_bandwidth`].
///
/// The dimension is taken from the first particle. An empty particle set is a no-op.
///
/// # Panics
/// Panics if `log_prob_grads.len()` differs from the number of particles, or if
/// the resolved bandwidth is not a positive finite number.
pub fn svgd_step(
    particles: &mut Distribution<Vec<f64>>,
    log_prob_grads: &[Vec<f64>],
    step_size: f64,
    bandwidth: Option<f64>,
) {
    let n = particles.samples.len();
    if n == 0 {
        return;
    }
    assert_eq!(log_prob_grads.len(), n);
    let d = particles.samples[0].len();
    let h = bandwidth.unwrap_or_else(|| median_bandwidth(&particles.samples));
    // A zero/NaN bandwidth turns every kernel value into NaN and silently
    // corrupts all particles; fail loudly instead.
    assert!(
        h.is_finite() && h > 0.0,
        "SVGD kernel bandwidth must be positive and finite, got {}",
        h
    );

    let samples = &particles.samples;
    let n_f = n as f64;
    let directions: Vec<Vec<f64>> = samples
        .iter()
        .map(|x_i| {
            let mut phi = vec![0.0; d];
            for (x_j, grad_j) in samples.iter().zip(log_prob_grads) {
                let k_ji = rbf_kernel(x_j, x_i, h);
                // Differentiate with respect to the neighbour x_j (the kernel's
                // first argument): (x_i - x_j) * k / h². This pushes x_i away
                // from x_j, i.e. it is the repulsive term.
                let repulsion = rbf_kernel_grad(x_j, x_i, h);
                for (dim, phi_d) in phi.iter_mut().enumerate() {
                    *phi_d += k_ji * grad_j[dim] + repulsion[dim];
                }
            }
            for phi_d in &mut phi {
                *phi_d /= n_f;
            }
            phi
        })
        .collect();

    for (particle, phi) in particles.samples.iter_mut().zip(&directions) {
        for (p_val, &delta) in particle.iter_mut().zip(phi) {
            *p_val += step_size * delta;
        }
    }
}
/// Performs one SVGD update for a simulator-based model with Gaussian observation noise.
///
/// For each particle θ, the gradient of the Gaussian log-likelihood
/// `log p(obs | θ) = −Σⱼ (obsⱼ − simⱼ(θ))² / (2σ²) + const`
/// is approximated with forward finite differences (absolute step `1e-6`).
/// If `prior_grad` is given, its output (the log-prior gradient at θ) is added.
/// The resulting per-particle gradients are passed to [`svgd_step`] with the
/// median-heuristic bandwidth.
///
/// The simulator is called `1 + dim` times per particle.
///
/// # Panics
/// Panics if:
/// * `noise_std` is not strictly positive;
/// * the simulator's output length differs from `observations.len()`, or
///   changes between calls;
/// * `prior_grad` returns fewer entries than the particle has dimensions.
pub fn svgd_inference_step<F, P>(
    particles: &mut Distribution<Vec<f64>>,
    simulator: F,
    observations: &[f64],
    noise_std: f64,
    step_size: f64,
    prior_grad: Option<P>,
) where
    F: Fn(&[f64]) -> Vec<f64>,
    P: Fn(&[f64]) -> Vec<f64>,
{
    // Absolute forward-difference step for d(simulator)/dθ.
    const FD_STEP: f64 = 1e-6;
    // σ = 0 would divide by zero and poison every gradient with inf/NaN.
    assert!(noise_std > 0.0, "noise_std must be strictly positive");
    let sigma2 = noise_std * noise_std;

    let mut log_prob_grads = Vec::with_capacity(particles.samples.len());
    for particle in &particles.samples {
        let pred = simulator(particle);
        assert_eq!(
            pred.len(),
            observations.len(),
            "simulator output length must match the number of observations"
        );
        // Perturb one coordinate at a time in a single reusable buffer instead
        // of cloning the particle for every dimension.
        let mut probe = particle.clone();
        let mut grad = Vec::with_capacity(particle.len());
        for (dim, &value) in particle.iter().enumerate() {
            probe[dim] = value + FD_STEP;
            let pred_plus = simulator(&probe);
            probe[dim] = value;
            assert_eq!(
                pred_plus.len(),
                pred.len(),
                "simulator output length must not depend on the parameters"
            );
            let grad_dim: f64 = pred
                .iter()
                .zip(&pred_plus)
                .zip(observations)
                .map(|((&p, &p_plus), &obs)| {
                    let d_pred = (p_plus - p) / FD_STEP;
                    d_pred * (obs - p) / sigma2
                })
                .sum();
            grad.push(grad_dim);
        }
        if let Some(prior_fn) = &prior_grad {
            let prior_g = prior_fn(particle);
            for (dim, g) in grad.iter_mut().enumerate() {
                *g += prior_g[dim];
            }
        }
        log_prob_grads.push(grad);
    }
    svgd_step(particles, &log_prob_grads, step_size, None);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rbf_kernel() {
        let x = vec![0.0, 0.0];
        let y = vec![1.0, 0.0];
        let k = rbf_kernel(&x, &y, 1.0);
        assert!((k - (-0.5_f64).exp()).abs() < 1e-6);
        let k_self = rbf_kernel(&x, &x, 1.0);
        assert!((k_self - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_rbf_kernel_grad() {
        let x = vec![1.0, 0.0];
        let y = vec![0.0, 0.0];
        let grad = rbf_kernel_grad(&x, &y, 1.0);
        assert!(grad[0] < 0.0);
        assert!(grad[1].abs() < 1e-10);
        // Analytic value: -(1 - 0) * exp(-0.5) / 1².
        assert!((grad[0] + (-0.5_f64).exp()).abs() < 1e-12);
    }

    /// The analytic gradient must agree with a central finite difference of the kernel.
    #[test]
    fn test_rbf_kernel_grad_matches_finite_difference() {
        let x = [0.3, -1.2, 0.7];
        let y = [-0.4, 0.5, 1.1];
        let h = 0.8;
        let grad = rbf_kernel_grad(&x, &y, h);
        let eps = 1e-6;
        for dim in 0..x.len() {
            let mut plus = x;
            plus[dim] += eps;
            let mut minus = x;
            minus[dim] -= eps;
            let fd = (rbf_kernel(&plus, &y, h) - rbf_kernel(&minus, &y, h)) / (2.0 * eps);
            assert!(
                (grad[dim] - fd).abs() < 1e-6,
                "dim {}: analytic {} vs finite difference {}",
                dim,
                grad[dim],
                fd
            );
        }
    }

    #[test]
    fn test_median_bandwidth() {
        let particles = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];
        let h = median_bandwidth(&particles);
        assert!((h - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_svgd_step_basic() {
        let mut particles =
            Distribution::uniform(vec![vec![-2.0], vec![-1.0], vec![1.0], vec![2.0]]);
        let log_prob_grads = vec![vec![2.0], vec![1.0], vec![-1.0], vec![-2.0]];
        svgd_step(&mut particles, &log_prob_grads, 0.1, Some(1.0));
        let mean_after: f64 = particles.samples.iter().map(|p| p[0]).sum::<f64>() / 4.0;
        assert!(mean_after.abs() < 1.0);
    }

    #[test]
    fn test_svgd_inference_simple() {
        let mut particles = Distribution::uniform(vec![vec![1.0], vec![2.0], vec![3.0]]);
        let observations = vec![2.0, 4.0];
        let simulator = |params: &[f64]| -> Vec<f64> {
            let a = params[0];
            vec![a * 1.0, a * 2.0]
        };
        for _ in 0..5 {
            svgd_inference_step(
                &mut particles,
                &simulator,
                &observations,
                0.1,
                0.01,
                None::<fn(&[f64]) -> Vec<f64>>,
            );
        }
        for particle in &particles.samples {
            assert!(particle[0].is_finite());
        }
    }
}