1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
use super::Kernel;
use crate::KernelError;
use crate::Value;
use crate::{KernelAdd, KernelMul};
use std::{fmt::Debug, ops::Add, ops::Mul};

/// A kernel assembled on the spot from two borrowed closures, without having
/// to declare a dedicated kernel type.
///
/// `value` evaluates the kernel for a pair of inputs under a parameter slice;
/// `value_with_grad` returns that value together with a gradient vector
/// (presumably with respect to the parameters — confirm against the `Kernel`
/// trait docs). The closures are only borrowed for `'a`, so the struct itself
/// is just a `usize` plus two shared references.
///
/// `Clone` and `Copy` are implemented by hand: `#[derive(Clone)]` would add a
/// `T: Clone` bound even though no `T` is ever stored here.
pub struct InstantKernel<'a, T>
where
  T: Value,
{
  /// Number of parameters this kernel reports through `Kernel::params_len`.
  params_len: usize,
  /// Closure invoked by `Kernel::value`.
  value: &'a (dyn Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Send + Sync),
  /// Closure invoked by `Kernel::value_with_grad`.
  value_with_grad:
    &'a (dyn Fn(&[f64], &T, &T) -> Result<(f64, Vec<f64>), KernelError> + Send + Sync),
}

impl<'a, T> Clone for InstantKernel<'a, T>
where
  T: Value,
{
  fn clone(&self) -> Self {
    *self
  }
}

// Every field is a `usize` or a shared reference, so a bitwise copy is sound
// regardless of whether `T` itself is `Clone` or `Copy`.
impl<'a, T> Copy for InstantKernel<'a, T>
where
  T: Value,
{
}

impl<'a, T> InstantKernel<'a, T>
where
  T: Value,
{
  pub fn new(
    params_len: usize,
    value: &'a (dyn Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Send + Sync),
    value_with_grad: &'a (dyn Fn(&[f64], &T, &T) -> Result<(f64, Vec<f64>), KernelError>
           + Send
           + Sync),
  ) -> Self {
    Self {
      params_len,
      value,
      value_with_grad,
    }
  }
}

impl<'a, T> Debug for InstantKernel<'a, T>
where
  T: Value,
{
  /// Formats the kernel as `InstantKernel { params_len: N }`.
  ///
  /// The stored closures are `dyn Fn` trait objects with no `Debug` bound, so
  /// only the parameter count is shown. Using `debug_struct` (instead of a
  /// hand-written `write!`) keeps the compact `{:?}` output unchanged while
  /// making `{:#?}` produce the standard multi-line pretty form.
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("InstantKernel")
      .field("params_len", &self.params_len)
      .finish()
  }
}

impl<'a, T: Value> Kernel<T> for InstantKernel<'a, T> {
  /// Returns the parameter count supplied at construction.
  fn params_len(&self) -> usize {
    self.params_len
  }

  /// Delegates to the stored `value` closure, passing every argument through
  /// unchanged and returning its result as-is.
  fn value(&self, params: &[f64], x: &T, xprime: &T) -> Result<f64, KernelError> {
    let evaluate = self.value;
    evaluate(params, x, xprime)
  }

  /// Delegates to the stored `value_with_grad` closure, passing every argument
  /// through unchanged and returning its result as-is.
  fn value_with_grad(
    &self,
    params: &[f64],
    x: &T,
    xprime: &T,
  ) -> Result<(f64, Vec<f64>), KernelError> {
    let evaluate_with_grad = self.value_with_grad;
    evaluate_with_grad(params, x, xprime)
  }
}

impl<'a, T, R> Add<R> for InstantKernel<'a, T>
where
  T: Value,
  R: Kernel<T>,
{
  type Output = KernelAdd<Self, R, T>;

  fn add(self, rhs: R) -> Self::Output {
    Self::Output::new(self, rhs)
  }
}

impl<'a, T, R> Mul<R> for InstantKernel<'a, T>
where
  T: Value,
  R: Kernel<T>,
{
  type Output = KernelMul<Self, R, T>;

  fn mul(self, rhs: R) -> Self::Output {
    Self::Output::new(self, rhs)
  }
}

#[cfg(test)]
mod tests {
  use crate::*;

  /// An `InstantKernel` with no parameters, a constant value of 0.0 and an
  /// empty gradient is an identity element for kernel addition, so `RBF + it`
  /// should reproduce RBF's own value and gradient for the same parameters.
  #[test]
  fn it_works() {
    let kernel = RBF + InstantKernel::new(0, &|_, _, _| Ok(0.0), &|_, _, _| Ok((0.0, vec![])));

    let params = [1.0, 1.0];
    let x = vec![1.0, 2.0, 3.0];
    let xprime = vec![3.0, 2.0, 1.0];

    let (func, grad) = kernel.value_with_grad(&params, &x, &xprime).unwrap();
    let (rbf_func, rbf_grad) = RBF.value_with_grad(&params, &x, &xprime).unwrap();

    assert!(func.is_finite());
    // One gradient entry per kernel parameter.
    assert_eq!(grad.len(), kernel.params_len());
    // Adding 0.0 and an empty gradient must not perturb RBF's output.
    assert_eq!(func, rbf_func);
    assert_eq!(grad, rbf_grad);
  }
}