1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
use crate::{Kernel, KernelMul}; use std::fmt::Debug; use std::marker::PhantomData; use std::{error::Error, ops::Add, ops::Mul}; #[derive(Clone, Debug)] pub struct KernelAdd<L, R, T> where L: Kernel<T>, R: Kernel<T>, T: Clone + Debug, { lhs: L, rhs: R, phantom: PhantomData<T>, } impl<L, R, T> KernelAdd<L, R, T> where L: Kernel<T>, R: Kernel<T>, T: Clone + Debug, { pub fn new(lhs: L, rhs: R) -> Self { Self { lhs, rhs, phantom: PhantomData, } } } impl<L, R, T> Kernel<T> for KernelAdd<L, R, T> where L: Kernel<T>, R: Kernel<T>, T: Clone + Debug, { fn params_len(&self) -> usize { self.lhs.params_len() + self.rhs.params_len() } fn value( &self, params: &[f64], x: &T, xprime: &T, with_grad: bool, ) -> Result<(f64, Vec<f64>), Box<dyn Error>> { let lhs_params_len = self.lhs.params_len(); let (fx, gfx) = self .lhs .value(¶ms[..lhs_params_len], x, xprime, with_grad)?; let (gx, ggx) = self .rhs .value(¶ms[lhs_params_len..], x, xprime, with_grad)?; let hx = fx + gx; let ghx = if !with_grad { vec![] } else { let ghx = [gfx, ggx].concat(); ghx }; Ok((hx, ghx)) } } impl<Rhs, L, R, T> Add<Rhs> for KernelAdd<L, R, T> where Rhs: Kernel<T>, L: Kernel<T>, R: Kernel<T>, T: Clone + Debug, { type Output = KernelAdd<Self, Rhs, T>; fn add(self, rhs: Rhs) -> Self::Output { Self::Output::new(self, rhs) } } impl<Rhs, L, R, T> Mul<Rhs> for KernelAdd<L, R, T> where Rhs: Kernel<T>, L: Kernel<T>, R: Kernel<T>, T: Clone + Debug, { type Output = KernelMul<Self, Rhs, T>; fn mul(self, rhs: Rhs) -> Self::Output { Self::Output::new(self, rhs) } }