1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
use super::Kernel;
use crate::KernelError;
use rayon::prelude::*;

/// Number of hyperparameters of the RBF kernel: the single divisor `theta` (`params[0]`).
const PARAMS_LEN: usize = 1;

/// Squared Euclidean distance `sum_i (x_i - x'_i)^2` between two slices.
///
/// Callers must ensure `x.len() == x_prime.len()`: `zip` silently stops at the shorter slice.
///
/// NOTE(review): each element costs only a subtraction and a multiply, so rayon's scheduling
/// overhead likely dominates for low-dimensional inputs; consider a sequential `iter()` after
/// profiling.
fn norm_pow(x: &[f64], x_prime: &[f64]) -> f64 {
    x.par_iter()
        .zip(x_prime.par_iter())
        .map(|(x_i, x_prime_i)| (x_i - x_prime_i).powi(2))
        .sum()
}

/// Builds a radial basis function kernel `k(x, x') = exp(-||x - x'||^2 / theta)`.
///
/// The kernel has a single hyperparameter `theta` (`params[0]`), initialised to `1000.0`.
/// When `with_grad` is `true`, the derivative with respect to `theta` is also returned:
/// `dk/dtheta = exp(-||x - x'||^2 / theta) * ||x - x'||^2 / theta^2`.
///
/// # Errors
///
/// The kernel function returns [`KernelError::InvalidArgument`] when `x` and `x_prime` have
/// different lengths, or when the parameter slice is empty.
pub fn rbf() -> Kernel<[f64]> {
    Kernel::<[f64]>::new(
        vec![1000.0; PARAMS_LEN],
        Box::new(
            |x: &[f64], x_prime: &[f64], with_grad: bool, params: &[f64]| {
                if x.len() != x_prime.len() {
                    return Err(KernelError::InvalidArgument.into());
                }
                // Indexing `params[0]` would panic on an empty slice; report it as an error.
                let theta = match params.first() {
                    Some(&theta) => theta,
                    None => return Err(KernelError::InvalidArgument.into()),
                };

                let sq_dist = norm_pow(x, x_prime);
                let func = (-sq_dist / theta).exp();

                let grad = if with_grad {
                    let mut grad = vec![0.0; PARAMS_LEN];
                    // d/dtheta exp(-r / theta) = exp(-r / theta) * r / theta^2.
                    // Reuse `func` instead of recomputing the exponential.
                    grad[0] = func * (sq_dist / theta.powi(2));
                    Some(grad)
                } else {
                    None
                };

                Ok((func, grad))
            },
        ),
    )
}

#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn it_works() {
        let kernel = rbf();
        let x: &[f64] = &[1.0, 2.0, 3.0];
        let x_prime: &[f64] = &[10.0, 20.0, 30.0];
        let (func, grad) = kernel.func(x, x_prime, true, None).unwrap();

        // ||x - x'||^2 = 9^2 + 18^2 + 27^2 = 1134, and the default theta is 1000.
        let expected = (-1134.0_f64 / 1000.0).exp();
        assert!((func - expected).abs() < 1e-12);

        let grad = grad.expect("gradient was requested");
        assert_eq!(grad.len(), 1);
        assert!((grad[0] - expected * (1134.0 / 1000.0_f64.powi(2))).abs() < 1e-12);
    }

    #[test]
    fn mismatched_lengths_are_rejected() {
        let kernel = rbf();
        let x: &[f64] = &[1.0, 2.0];
        let x_prime: &[f64] = &[1.0];
        assert!(kernel.func(x, x_prime, false, None).is_err());
    }

    #[test]
    fn identical_inputs_give_one_without_gradient() {
        let kernel = rbf();
        let x: &[f64] = &[0.5, -1.5];
        let (func, grad) = kernel.func(x, x, false, None).unwrap();
        assert!((func - 1.0).abs() < 1e-15);
        assert!(grad.is_none());
    }
}