1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
use crate::{Kernel, KernelError};
use std::{error::Error, fmt::Debug};

/// A value that can be split into a number of sub-vectors ("parts"), so that a
/// kernel defined over `Vec<f64>` can be applied part by part.
pub trait Convolutable: Clone + Debug + Sync + Send {
    /// Number of parts this value is split into.
    fn parts_len(&self) -> usize;

    /// Returns the `index`-th part.
    ///
    /// [`Convolutional`] only calls this with indices in `0..self.parts_len()`.
    fn part(&self, index: usize) -> &Vec<f64>;

    /// Length of the underlying data; for `Vec<f64>` this is its `len()`.
    fn data_len(&self) -> usize;
}

/// A plain vector is treated as a single part: the whole vector itself.
impl Convolutable for Vec<f64> {
    fn parts_len(&self) -> usize {
        1
    }

    fn part(&self, _: usize) -> &Vec<f64> {
        self
    }

    fn data_len(&self) -> usize {
        self.len()
    }
}

/// Kernel that applies an inner kernel `K` to each pair of corresponding parts
/// of two [`Convolutable`] inputs and sums the results:
/// `k(x, x') = Σ_i K(x_i, x'_i)`, with every part sharing `K`'s parameters.
#[derive(Clone, Debug)]
pub struct Convolutional<K>
where
    K: Kernel<Vec<f64>>,
{
    kernel: K,
}

impl<K> Convolutional<K>
where
    K: Kernel<Vec<f64>>,
{
    /// Wraps `kernel` so that it is applied part-wise and summed.
    pub fn new(kernel: K) -> Self {
        Self { kernel }
    }

    /// Borrows the wrapped inner kernel.
    pub fn kernel_ref(&self) -> &K {
        &self.kernel
    }
}

impl<T, K> Kernel<T> for Convolutional<K>
where
    T: Convolutable,
    K: Kernel<Vec<f64>>,
{
    /// This kernel has no parameters of its own; it exposes exactly the inner
    /// kernel's parameters.
    fn params_len(&self) -> usize {
        self.kernel.params_len()
    }

    /// Evaluates `Σ_i K(params, x_i, x'_i)` over all corresponding parts.
    ///
    /// Every part is evaluated with the *same* `params`, so the gradient with
    /// respect to those parameters is the element-wise sum of the per-part
    /// gradients. When `with_grad` is true it therefore has `params_len()`
    /// entries, including the zero-part case.
    ///
    /// # Errors
    ///
    /// - [`KernelError::ParametersLengthMismatch`] if `params.len()` differs from
    ///   the inner kernel's `params_len()`.
    /// - [`KernelError::InvalidArgument`] if `x` and `xprime` have a different
    ///   number of parts.
    /// - Any error returned by the inner kernel for some part (propagated
    ///   unchanged; evaluation stops at the first failing part).
    fn value(
        &self,
        params: &[f64],
        x: &T,
        xprime: &T,
        with_grad: bool,
    ) -> Result<(f64, Vec<f64>), Box<dyn Error>> {
        let params_len = self.kernel.params_len();
        if params.len() != params_len {
            return Err(KernelError::ParametersLengthMismatch.into());
        }
        let p = x.parts_len();
        if p != xprime.parts_len() {
            return Err(KernelError::InvalidArgument.into());
        }

        let mut fx = 0.0;
        // Parameters are shared across parts, so per-part gradients are summed
        // element-wise; concatenating them would yield `p * params_len` entries
        // that no longer correspond to `params`.
        let mut gfx = if with_grad {
            vec![0.0; params_len]
        } else {
            Vec::new()
        };

        for pi in 0..p {
            let (part_value, part_grad) =
                self.kernel
                    .value(params, x.part(pi), xprime.part(pi), with_grad)?;
            fx += part_value;

            // Grow defensively in case the inner kernel reports a gradient whose
            // length differs from `params_len()` (e.g. one even when `with_grad`
            // is false), so no component is silently dropped.
            if gfx.len() < part_grad.len() {
                gfx.resize(part_grad.len(), 0.0);
            }
            for (acc, g) in gfx.iter_mut().zip(part_grad) {
                *acc += g;
            }
        }

        Ok((fx, gfx))
    }
}