1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
use super::Kernel;
use crate::KernelError;
use rayon::prelude::*;

/// Number of hyperparameters taken by the periodic kernel: `params[0]` scales
/// the cosine term and `params[1]` divides the distance. It is also the length
/// of the default parameter vector and of the gradient vector the kernel returns.
const PARAMS_LEN: usize = 2;

/// Euclidean distance `‖x − x'‖₂` between two equal-length vectors.
///
/// The sum is computed sequentially on purpose. Each term is a single
/// subtract-and-square, far too little work to amortise a thread pool's
/// scheduling overhead. A parallel floating-point reduction also has no fixed
/// summation order, so repeated calls could return slightly different
/// distances. A sequential fold keeps the result deterministic.
///
/// Callers must ensure `x.len() == x_prime.len()`; otherwise `zip` would
/// silently ignore the tail of the longer slice. This is checked in debug builds.
fn norm(x: &[f64], x_prime: &[f64]) -> f64 {
    debug_assert_eq!(x.len(), x_prime.len(), "norm: length mismatch");
    x.iter()
        .zip(x_prime)
        .map(|(x_i, x_prime_i)| (x_i - x_prime_i).powi(2))
        .sum::<f64>()
        .sqrt()
}

/// Builds the periodic kernel
/// `k(x, x') = exp(θ₀ · cos(r / θ₁))`, where `r = ‖x − x'‖₂` is the Euclidean
/// distance between the inputs.
///
/// The kernel takes two parameters, both initialised to `1.0`:
/// - `params[0]` (θ₀) scales the cosine inside the exponential;
/// - `params[1]` (θ₁) divides the distance, so the kernel repeats whenever `r`
///   grows by `2π·θ₁`.
///
/// When `with_grad` is `true`, the kernel function also returns the partial
/// derivatives with respect to the parameters, in parameter order:
/// - `∂k/∂θ₀ = k · cos(r / θ₁)`
/// - `∂k/∂θ₁ = k · θ₀ · sin(r / θ₁) · r / θ₁²`
///
/// # Errors
///
/// The kernel function returns an error built from
/// `KernelError::InvalidArgument` if `x` and `x_prime` differ in length, or if
/// fewer than two parameters are supplied.
pub fn periodic() -> Kernel<[f64]> {
    Kernel::<[f64]>::new(
        vec![1.0; PARAMS_LEN],
        Box::new(
            |x: &[f64], x_prime: &[f64], with_grad: bool, params: &[f64]| {
                if x.len() != x_prime.len() {
                    return Err(KernelError::InvalidArgument.into());
                }
                // Without this guard the indexing below would panic on a short slice.
                if params.len() < PARAMS_LEN {
                    return Err(KernelError::InvalidArgument.into());
                }

                let r = norm(x, x_prime);
                let phase = r / params[1];
                let cos_phase = phase.cos();
                let func = (params[0] * cos_phase).exp();

                // `func` is exactly the exponential factor shared by both
                // derivatives, so it is reused instead of recomputed.
                let grad = with_grad.then(|| {
                    vec![
                        func * cos_phase,
                        func * params[0] * phase.sin() * (r / params[1].powi(2)),
                    ]
                });

                Ok((func, grad))
            },
        ),
    )
}