1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
use super::Kernel;
use crate::{KernelAdd, KernelError, KernelMul};
use std::{
    error::Error,
    fmt::Debug,
    ops::{Add, Mul},
};

/// Number of hyperparameters consumed by [`Bias`]: a single constant value.
const PARAMS_LEN: usize = 1;

/// Constant ("bias") kernel.
///
/// The kernel never looks at its two inputs: for any pair it evaluates to
/// the single parameter `params[0]`, and the gradient with respect to that
/// parameter is always `1.0`.
#[derive(Clone, Debug)]
pub struct Bias;

impl<T> Kernel<T> for Bias
where
    T: Clone + Debug,
{
    /// Returns how many parameters [`Kernel::value`] expects (always `PARAMS_LEN`).
    fn params_len(&self) -> usize {
        PARAMS_LEN
    }

    /// Evaluates the kernel for an (ignored) pair of inputs.
    ///
    /// The two input values and the boolean argument are not used. On success
    /// the result is `(params[0], vec![1.0])`: the kernel value followed by its
    /// gradient with respect to each parameter.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::ParametersLengthMismatch` (boxed) when `params`
    /// does not hold exactly `PARAMS_LEN` values.
    fn value(
        &self,
        params: &[f64],
        _: &T,
        _: &T,
        _: bool,
    ) -> Result<(f64, Vec<f64>), Box<dyn Error>> {
        // `PARAMS_LEN` is a const, so it can be matched on directly; once the
        // length is confirmed, indexing `params[0]` cannot panic.
        match params.len() {
            PARAMS_LEN => Ok((params[0], vec![1.0])),
            _ => Err(KernelError::ParametersLengthMismatch.into()),
        }
    }
}

/// `bias + rhs` is shorthand for `KernelAdd::new(bias, rhs)`.
impl<R> Add<R> for Bias
where
    R: Kernel<Vec<f64>>,
{
    type Output = KernelAdd<Self, R, Vec<f64>>;

    fn add(self, rhs: R) -> Self::Output {
        KernelAdd::<Self, R, Vec<f64>>::new(self, rhs)
    }
}

/// `bias * rhs` is shorthand for `KernelMul::new(bias, rhs)`.
impl<R> Mul<R> for Bias
where
    R: Kernel<Vec<f64>>,
{
    type Output = KernelMul<Self, R, Vec<f64>>;

    fn mul(self, rhs: R) -> Self::Output {
        KernelMul::<Self, R, Vec<f64>>::new(self, rhs)
    }
}