1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
extern crate rayon; extern crate thiserror; pub use ard::ard; pub use bias::bias; pub use exponential::exponential; pub use linear::linear; pub use periodic::periodic; pub use rbf::rbf; use std::error::Error; use std::fmt::Debug; pub mod ard; pub mod bias; pub mod exponential; pub mod linear; pub mod ops; pub mod periodic; pub mod rbf; pub type Func<T> = Box<dyn Fn(&T, &T, bool, &[f64]) -> Result<(f64, Option<Vec<f64>>), Box<dyn Error>>>; pub struct Kernel<T> where T: ?Sized, { params: Vec<f64>, func: Func<T>, } impl<T> Debug for Kernel<T> where T: ?Sized, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:#?}", self.params) } } unsafe impl<T> Send for Kernel<T> where T: ?Sized {} unsafe impl<T> Sync for Kernel<T> where T: ?Sized {} impl<T> Kernel<T> where T: ?Sized, { pub fn new(params: Vec<f64>, func: Func<T>) -> Self { Self { params, func } } pub fn params(&self) -> &[f64] { &self.params } pub fn params_mut(&mut self) -> &mut [f64] { &mut self.params } pub fn func( &self, x: &T, x_prime: &T, with_grad: bool, rewrite_params: Option<&[f64]>, ) -> Result<(f64, Option<Vec<f64>>), Box<dyn Error>> { (self.func)( x, x_prime, with_grad, match rewrite_params { None => &self.params, Some(v) => { if self.params.len() != v.len() { return Err(KernelError::ParametersLengthMismatch.into()); } v } }, ) } } #[derive(thiserror::Error, Debug)] pub enum KernelError { #[error("invalid argument")] InvalidArgument, #[error("invalid parameter")] InvalidParameter, #[error("parameters length mismatch")] ParametersLengthMismatch, } #[cfg(test)] mod tests { use crate::*; #[test] fn it_works() { let kernel = bias() + bias() * linear() + bias() * rbf() + bias() * periodic(); let (func, grad) = kernel .func(&vec![1.0, 2.0, 3.0], &vec![10.0, 20.0, 30.0], true, None) .unwrap(); println!("{}", func); println!("{:#?}", grad); } }