1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
use super::Kernel;
use crate::KernelError;
use crate::Value;
use crate::{KernelAdd, KernelMul};
use std::{fmt::Debug, ops::Add, ops::Mul};

/// A [`Kernel`] whose evaluation is delegated to caller-supplied closures.
///
/// This lets an ad-hoc kernel be defined inline (for example to combine it with
/// another kernel via `+` or `*`) without declaring a new type. The closures are
/// borrowed for `'a`; the kernel itself only stores those two references and a
/// parameter count, which is why it is `Copy`.
pub struct InstantKernel<'a, T> {
    /// Value returned by [`Kernel::params_len`].
    params_len: usize,
    /// Invoked as `value(params, x, xprime)` by [`Kernel::value`].
    value: &'a (dyn Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Send + Sync),
    /// Invoked as `value_with_grad(params, x, xprime)` by [`Kernel::value_with_grad`].
    value_with_grad:
        &'a (dyn Fn(&[f64], &T, &T) -> Result<(f64, Vec<f64>), KernelError> + Send + Sync),
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add a
// `T: Clone` bound, but the struct only holds references and a `usize`, so it
// can be copied for every `T`.
impl<T> Clone for InstantKernel<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for InstantKernel<'_, T> {}

impl<'a, T> InstantKernel<'a, T>
where
    T: Value,
{
    /// Creates a kernel that reports `params_len` parameters and forwards every
    /// evaluation to the given closures.
    ///
    /// # Arguments
    ///
    /// * `params_len` - value returned by [`Kernel::params_len`].
    /// * `value` - called with `(params, x, xprime)` by [`Kernel::value`]; its
    ///   result (including any `Err`) is returned unchanged.
    /// * `value_with_grad` - called with `(params, x, xprime)` by
    ///   [`Kernel::value_with_grad`]; its result is returned unchanged. The
    ///   returned `Vec<f64>` is presumably expected to hold one entry per
    ///   parameter — TODO confirm against the `Kernel` trait contract.
    pub fn new(
        params_len: usize,
        value: &'a (dyn Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Send + Sync),
        value_with_grad: &'a (dyn Fn(&[f64], &T, &T) -> Result<(f64, Vec<f64>), KernelError>
                 + Send
                 + Sync),
    ) -> Self {
        Self {
            params_len,
            value,
            value_with_grad,
        }
    }
}

impl<T> Debug for InstantKernel<'_, T> {
    /// Formats as `InstantKernel { params_len: N }`.
    ///
    /// The closures are opaque trait objects, so only the parameter count is
    /// shown. `debug_struct` also honours the alternate (`{:#?}`) flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InstantKernel")
            .field("params_len", &self.params_len)
            .finish()
    }
}

impl<T> Kernel<T> for InstantKernel<'_, T>
where
    T: Value,
{
    fn params_len(&self) -> usize {
        self.params_len
    }

    /// Delegates to the `value` closure supplied to [`InstantKernel::new`].
    fn value(&self, params: &[f64], x: &T, xprime: &T) -> Result<f64, KernelError> {
        (self.value)(params, x, xprime)
    }

    /// Delegates to the `value_with_grad` closure supplied to [`InstantKernel::new`].
    fn value_with_grad(
        &self,
        params: &[f64],
        x: &T,
        xprime: &T,
    ) -> Result<(f64, Vec<f64>), KernelError> {
        (self.value_with_grad)(params, x, xprime)
    }
}

/// `instant + rhs` builds a [`KernelAdd`] of the two kernels.
impl<T, R> Add<R> for InstantKernel<'_, T>
where
    T: Value,
    R: Kernel<T>,
{
    type Output = KernelAdd<Self, R, T>;

    fn add(self, rhs: R) -> Self::Output {
        Self::Output::new(self, rhs)
    }
}

/// `instant * rhs` builds a [`KernelMul`] of the two kernels.
impl<T, R> Mul<R> for InstantKernel<'_, T>
where
    T: Value,
    R: Kernel<T>,
{
    type Output = KernelMul<Self, R, T>;

    fn mul(self, rhs: R) -> Self::Output {
        Self::Output::new(self, rhs)
    }
}

#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn it_works() {
        let kernel =
            RBF + InstantKernel::new(0, &|_, _, _| Ok(0.0), &|_, _, _| Ok((0.0, vec![])));
        let (func, grad) = kernel
            .value_with_grad(&[1.0, 1.0], &vec![1.0, 2.0, 3.0], &vec![3.0, 2.0, 1.0])
            .unwrap();
        assert!(func.is_finite());
        println!("{}", func);
        println!("{:#?}", grad);
    }

    #[test]
    fn forwards_to_closures_and_is_copy() {
        let kernel = InstantKernel::<Vec<f64>>::new(
            1,
            &|params, _, _| Ok(params[0] * 2.0),
            &|params, _, _| Ok((params[0] * 2.0, vec![2.0])),
        );

        // `InstantKernel` is `Copy`, so `kernel` stays usable after this.
        let copied = kernel;

        assert_eq!(kernel.params_len(), 1);
        assert_eq!(copied.value(&[3.0], &vec![0.0], &vec![0.0]).unwrap(), 6.0);

        let (value, grad) = kernel
            .value_with_grad(&[3.0], &vec![0.0], &vec![0.0])
            .unwrap();
        assert_eq!(value, 6.0);
        assert_eq!(grad, vec![2.0]);
        assert_eq!(format!("{:?}", kernel), "InstantKernel { params_len: 1 }");
    }
}