1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
use super::Kernel;
use crate::KernelError;
use rayon::prelude::*;

/// Number of hyperparameters the exponential kernel takes: a single value,
/// `params[0]`, which divides the Euclidean distance inside `exp`.
const PARAMS_LEN: usize = 1;

fn norm(x: &[f64], x_prime: &[f64]) -> f64 {
    x.par_iter()
        .zip(x_prime.par_iter())
        .map(|(x_i, x_prime_i)| (x_i - x_prime_i).powi(2))
        .sum::<f64>()
        .sqrt()
}

/// Builds an exponential kernel over `f64` slices.
///
/// The kernel value is `k(x, x') = exp(-r / l)`, where `r = ‖x - x'‖` is the
/// Euclidean distance and `l = params[0]` is the single parameter (initialised
/// to `1000.0`).
///
/// When `with_grad` is `true`, the closure also returns the derivative of `k`
/// with respect to `l`:
///
/// `∂k/∂l = exp(-r / l) · r / l²`
///
/// # Errors
///
/// The closure returns `KernelError::InvalidArgument` if `x` and `x_prime`
/// have different lengths, or if `params` holds fewer than `PARAMS_LEN` values.
pub fn exponential() -> Kernel<[f64]> {
    Kernel::<[f64]>::new(
        vec![1000.0; PARAMS_LEN],
        Box::new(
            |x: &[f64], x_prime: &[f64], with_grad: bool, params: &[f64]| {
                // Report bad input as an error rather than silently truncating
                // in `norm` or panicking on `params[0]`.
                if x.len() != x_prime.len() || params.len() < PARAMS_LEN {
                    return Err(KernelError::InvalidArgument.into());
                }

                let length_scale = params[0];
                let distance = norm(x, x_prime);

                let func = (-distance / length_scale).exp();

                // Chain rule: d/dl exp(-r / l) = exp(-r / l) * (r / l^2).
                // Reuse `func` for exp(-r / l); note the factor `r`.
                let grad =
                    with_grad.then(|| vec![func * distance / length_scale.powi(2)]);

                Ok((func, grad))
            },
        ),
    )
}