1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
use super::Kernel;
use crate::{KernelAdd, KernelError, KernelMul};
use rayon::prelude::*;
use std::ops::{Add, Mul};

/// Number of hyper-parameters expected by [`RBF`]: `[θ₀, θ₁]`.
const PARAMS_LEN: usize = 2;

/// Radial basis function kernel over `Vec<f64>` inputs.
///
/// With hyper-parameters `params = [θ₀, θ₁]` it evaluates
///
/// ```text
/// k(x, x') = θ₀ · exp(−‖x − x'‖² / θ₁)
/// ```
///
/// where `‖x − x'‖²` is the squared Euclidean distance. θ₁ divides the squared
/// distance directly (no `2ℓ²` convention is applied).
#[derive(Clone, Debug)]
pub struct RBF;

impl RBF {
    /// Validates the arguments and returns the squared Euclidean distance `‖x − x'‖²`.
    ///
    /// Once this returns `Ok`, `params` is known to have exactly [`PARAMS_LEN`]
    /// elements, so callers may index `params[0]` and `params[1]` directly.
    ///
    /// # Errors
    ///
    /// * [`KernelError::ParametersLengthMismatch`] if `params.len() != PARAMS_LEN`.
    /// * [`KernelError::InvalidArgument`] if `x` and `xprime` have different lengths.
    fn norm_pow(&self, params: &[f64], x: &[f64], xprime: &[f64]) -> Result<f64, KernelError> {
        if params.len() != PARAMS_LEN {
            return Err(KernelError::ParametersLengthMismatch);
        }
        if x.len() != xprime.len() {
            return Err(KernelError::InvalidArgument);
        }
        // NOTE(review): each element is a single subtract-and-square, so rayon's
        // scheduling overhead likely outweighs the work for typical input sizes, and a
        // parallel float reduction may not associate terms in a fixed order. Consider
        // a sequential `iter()` if profiling or bit-for-bit reproducibility warrants it.
        let norm_pow = x
            .par_iter()
            .zip(xprime.par_iter())
            .map(|(x_i, xprime_i)| (x_i - xprime_i).powi(2))
            .sum::<f64>();
        Ok(norm_pow)
    }
}

impl Kernel<Vec<f64>> for RBF {
    /// Returns [`PARAMS_LEN`]: the kernel takes `[θ₀, θ₁]`.
    fn params_len(&self) -> usize {
        PARAMS_LEN
    }

    /// Evaluates `k(x, x') = θ₀ · exp(−‖x − x'‖² / θ₁)`.
    ///
    /// θ₁ is not validated here; θ₁ = 0 divides by zero and yields a degenerate
    /// (zero or NaN) result.
    ///
    /// # Errors
    ///
    /// Propagates the validation errors of the squared-distance helper (wrong
    /// `params` length, or `x`/`xprime` length mismatch).
    fn value(&self, params: &[f64], x: &Vec<f64>, xprime: &Vec<f64>) -> Result<f64, KernelError> {
        let norm_pow = self.norm_pow(params, x, xprime)?;
        Ok(params[0] * (-norm_pow / params[1]).exp())
    }

    /// Evaluates the kernel together with its gradient w.r.t. the hyper-parameters.
    ///
    /// Returns `(k, [∂k/∂θ₀, ∂k/∂θ₁])` where, with `d = ‖x − x'‖²`,
    ///
    /// ```text
    /// ∂k/∂θ₀ = exp(−d / θ₁)
    /// ∂k/∂θ₁ = θ₀ · exp(−d / θ₁) · d / θ₁²
    /// ```
    ///
    /// # Errors
    ///
    /// Same as [`Kernel::value`].
    fn value_with_grad(
        &self,
        params: &[f64],
        x: &Vec<f64>,
        xprime: &Vec<f64>,
    ) -> Result<(f64, Vec<f64>), KernelError> {
        let norm_pow = self.norm_pow(params, x, xprime)?;
        // Shared exponential factor; computed once instead of once per output.
        let exp_term = (-norm_pow / params[1]).exp();
        let fx = params[0] * exp_term;
        // `fx * (d / θ₁²)` keeps the original evaluation order, so results are
        // bit-identical to computing `θ₀ · exp(…) · (d / θ₁²)` inline.
        let gfx = vec![exp_term, fx * (norm_pow / params[1].powi(2))];
        Ok((fx, gfx))
    }
}

/// `rbf + other` composes this kernel with `rhs` through [`KernelAdd`].
impl<R> Add<R> for RBF
where
    R: Kernel<Vec<f64>>,
{
    type Output = KernelAdd<Self, R, Vec<f64>>;

    fn add(self, rhs: R) -> Self::Output {
        Self::Output::new(self, rhs)
    }
}

/// `rbf * other` composes this kernel with `rhs` through [`KernelMul`].
impl<R> Mul<R> for RBF
where
    R: Kernel<Vec<f64>>,
{
    type Output = KernelMul<Self, R, Vec<f64>>;

    fn mul(self, rhs: R) -> Self::Output {
        Self::Output::new(self, rhs)
    }
}

#[cfg(test)]
mod tests {
    use crate::*;

    const TOL: f64 = 1e-12;

    #[test]
    fn it_works() {
        let kernel = RBF;
        let (func, grad) = kernel
            .value_with_grad(&[1.0, 1.0], &vec![1.0, 2.0, 3.0], &vec![3.0, 2.0, 1.0])
            .unwrap();
        // ‖x − x'‖² = 2² + 0² + 2² = 8, so with θ₀ = θ₁ = 1: k = exp(−8).
        let expected = (-8.0_f64).exp();
        assert!((func - expected).abs() < TOL);
        assert_eq!(grad.len(), 2);
        assert!((grad[0] - expected).abs() < TOL);
        assert!((grad[1] - 8.0 * expected).abs() < TOL);
    }

    #[test]
    fn value_matches_value_with_grad() {
        let kernel = RBF;
        let params = [2.0, 0.5];
        let x = vec![0.5, -1.0];
        let xprime = vec![1.5, 1.0];
        let value = kernel.value(&params, &x, &xprime).unwrap();
        let (func, _) = kernel.value_with_grad(&params, &x, &xprime).unwrap();
        assert_eq!(value, func);
    }

    #[test]
    fn rejects_invalid_arguments() {
        let kernel = RBF;
        // Wrong number of hyper-parameters.
        assert!(kernel.value(&[1.0], &vec![1.0], &vec![1.0]).is_err());
        // Mismatched input lengths.
        assert!(kernel
            .value(&[1.0, 1.0], &vec![1.0, 2.0], &vec![1.0])
            .is_err());
    }
}