1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
use crate::Kernel; use rayon::prelude::*; use std::fmt::Debug; #[derive(Clone, Debug)] pub struct ARD { data_len: usize, params: Vec<f64>, } impl ARD { pub fn new(data_len: usize, params: Vec<f64>) -> Self { Self { data_len, params } } fn weighted_norm(x: &Vec<f64>, x_prime: &Vec<f64>, params: &[f64]) -> f64 { let relevances = ¶ms[3..]; x.par_iter() .zip(x_prime.par_iter()) .zip(relevances.par_iter()) .map(|((x_i, x_prime_i), relevance)| relevance * (x_i - x_prime_i).powi(2)) .sum() } fn prod(x: &Vec<f64>, x_prime: &Vec<f64>) -> f64 { x.par_iter() .zip(x_prime.par_iter()) .map(|(x_i, x_prime_i)| x_i * x_prime_i) .sum() } } impl Kernel<Vec<f64>> for ARD { fn get_params(&self) -> &[f64] { &self.params } fn set_params(&mut self, params: &[f64]) -> Result<(), String> { let n = params.len(); if n != 3 + self.data_len { return Err("dimension mismatch".to_owned()); } for i in 0..n { self.params[i] = params[i]; } Ok(()) } fn value(&self, x: &Vec<f64>, x_prime: &Vec<f64>) -> Result<f64, String> { if x.len() != x_prime.len() { return Err("dimension mismatch".to_owned()); } Ok( self.params[0] * (-ARD::weighted_norm(x, x_prime, &self.params)).exp() + self.params[1] + self.params[2] * ARD::prod(x, x_prime), ) } fn grad( &self, x: &Vec<f64>, x_prime: &Vec<f64>, ) -> Result<Box<dyn Fn(&[f64]) -> Result<Vec<f64>, String>>, String> { if x.len() != x_prime.len() { return Err("dimension mismatch".to_owned()); } let params_len = self.params.len(); let x = x.to_vec(); let x_prime = x_prime.to_vec(); Ok(Box::new(move |params: &[f64]| { if params.len() != params_len { return Err("dimension mismatch".to_owned()); } let mut grad = vec![f64::default(); params_len]; grad[0] = (-ARD::weighted_norm(&x, &x_prime, params)).exp(); grad[1] = 1.0; grad[2] = ARD::prod(&x, &x_prime); let relevances_grad = &mut grad[3..]; relevances_grad .par_iter_mut() .zip(x.par_iter()) .zip(x_prime.par_iter()) .for_each(|((s, &l), &r)| *s = -(l - r).powi(2)); Ok(grad) })) } }