1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
use super::Kernel;
use crate::KernelError;
use rayon::prelude::*;

pub fn linear() -> Kernel<[f64]> {
    Kernel::<[f64]>::new(
        vec![],
        Box::new(|x: &[f64], x_prime: &[f64], with_grad: bool, _: &[f64]| {
            if x.len() != x_prime.len() {
                return Err(KernelError::InvalidArgument.into());
            }

            let func = x
                .par_iter()
                .zip(x_prime.par_iter())
                .map(|(x_i, x_prime_i)| x_i * x_prime_i)
                .sum();

            let grad = if !with_grad { None } else { Some(vec![]) };

            Ok((func, grad))
        }),
    )
}