1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
use crate::{Kernel, KernelAdd}; use rayon::prelude::*; use std::marker::PhantomData; use std::{error::Error, ops::Add}; use std::{fmt::Debug, ops::Mul}; #[derive(Clone, Debug)] pub struct KernelMul<L, R, T> where L: Kernel<T>, R: Kernel<T>, T: Clone + Debug, { lhs: L, rhs: R, phantom: PhantomData<T>, } impl<L, R, T> KernelMul<L, R, T> where L: Kernel<T>, R: Kernel<T>, T: Clone + Debug, { pub fn new(lhs: L, rhs: R) -> Self { Self { lhs, rhs, phantom: PhantomData, } } } impl<L, R, T> Kernel<T> for KernelMul<L, R, T> where L: Kernel<T>, R: Kernel<T>, T: Clone + Debug, { fn params_len(&self) -> usize { self.lhs.params_len() + self.rhs.params_len() } fn value( &self, params: &[f64], x: &T, xprime: &T, with_grad: bool, ) -> Result<(f64, Vec<f64>), Box<dyn Error>> { let lhs_params_len = self.lhs.params_len(); let (fx, dfx) = self .lhs .value(¶ms[..lhs_params_len], x, xprime, with_grad)?; let (gx, dgx) = self .rhs .value(¶ms[lhs_params_len..], x, xprime, with_grad)?; let hx = fx * gx; let ghx = if !with_grad { vec![] } else { let ghx = dfx .par_iter() .map(|dfxi| dfxi * gx) .chain(dgx.par_iter().map(|dgxi| fx * dgxi)) .collect::<Vec<_>>(); ghx }; Ok((hx, ghx)) } } impl<Rhs, L, R, T> Add<Rhs> for KernelMul<L, R, T> where Rhs: Kernel<T>, L: Kernel<T>, R: Kernel<T>, T: Clone + Debug, { type Output = KernelAdd<Self, Rhs, T>; fn add(self, rhs: Rhs) -> Self::Output { Self::Output::new(self, rhs) } } impl<Rhs, L, R, T> Mul<Rhs> for KernelMul<L, R, T> where Rhs: Kernel<T>, L: Kernel<T>, R: Kernel<T>, T: Clone + Debug, { type Output = KernelMul<Self, Rhs, T>; fn mul(self, rhs: Rhs) -> Self::Output { Self::Output::new(self, rhs) } }