1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
use crate::Value;
use crate::{Kernel, KernelError};
use rayon::prelude::*;
use std::fmt::Debug;

/// An input that can be viewed as an ordered sequence of `Vec<f64>` parts,
/// so that a kernel defined over `Vec<f64>` can be applied part-by-part.
pub trait Convolutable: Value {
  /// How many parts this value exposes through [`Convolutable::part`].
  fn parts_len(&self) -> usize;

  /// Borrows the part at position `index`.
  ///
  /// NOTE(review): bounds behaviour is up to the implementor; the
  /// `Vec<f64>` impl ignores `index` entirely.
  fn part(&self, index: usize) -> &Vec<f64>;

  /// Length of the underlying data (for `Vec<f64>` this is `len()`).
  fn data_len(&self) -> usize;
}

/// A plain vector is a single-part input whose only part is the vector itself.
impl Convolutable for Vec<f64> {
  /// Number of elements in the vector.
  fn data_len(&self) -> usize {
    Vec::len(self)
  }

  /// Always exactly one part.
  fn parts_len(&self) -> usize {
    1
  }

  /// Returns the whole vector; `_index` is not inspected.
  fn part(&self, _index: usize) -> &Vec<f64> {
    self
  }
}

/// Kernel adapter that applies an inner kernel over `Vec<f64>` to each part
/// of a [`Convolutable`] input and sums the per-part results.
///
/// The `Kernel<Vec<f64>>` bound is placed on the `impl` blocks that actually
/// use the inner kernel rather than on the struct, so the type can be named
/// and stored without repeating the bound. The derives still require
/// `K: Clone` / `K: Debug` respectively.
#[derive(Clone, Debug)]
pub struct Convolutional<K> {
  /// The kernel evaluated on each individual part.
  kernel: K,
}

impl<K> Convolutional<K>
where
  K: Kernel<Vec<f64>>,
{
  pub fn new(kernel: K) -> Self {
    Self { kernel }
  }

  pub fn kernel_ref(&self) -> &K {
    &self.kernel
  }
}

/// Evaluates the inner kernel on each pair of corresponding parts and sums
/// the results.
///
/// Every part is evaluated with the same `params`. The parameter count is
/// therefore the inner kernel's, and the gradient is the element-wise sum of
/// the per-part gradients.
impl<T, K> Kernel<T> for Convolutional<K>
where
  T: Convolutable,
  K: Kernel<Vec<f64>>,
{
  /// Same parameter count as the inner kernel (parameters are shared by all parts).
  fn params_len(&self) -> usize {
    self.kernel.params_len()
  }

  /// Returns `sum_i k(params, x.part(i), xprime.part(i))`, evaluating the
  /// parts in parallel with rayon. Yields `0.0` when there are no parts.
  ///
  /// # Errors
  /// - `KernelError::ParametersLengthMismatch` if `params.len()` differs from
  ///   the inner kernel's `params_len()`.
  /// - `KernelError::InvalidArgument` if `x` and `xprime` have different
  ///   numbers of parts.
  /// - Any error returned by the inner kernel for some part.
  fn value(&self, params: &[f64], x: &T, xprime: &T) -> Result<f64, KernelError> {
    if params.len() != self.kernel.params_len() {
      return Err(KernelError::ParametersLengthMismatch);
    }
    let p = x.parts_len();
    if p != xprime.parts_len() {
      return Err(KernelError::InvalidArgument);
    }

    // Summing `Result`s short-circuits into the first error encountered.
    let fx = (0..p)
      .into_par_iter()
      .map(|pi| self.kernel.value(params, x.part(pi), xprime.part(pi)))
      .sum::<Result<f64, KernelError>>()?;

    Ok(fx)
  }

  /// Returns the summed kernel value together with its gradient with respect
  /// to `params`.
  ///
  /// The gradient always has length `params.len()`. Because every part
  /// shares `params`, it is the element-wise sum of the inner kernel's
  /// per-part gradients. With zero parts it is all zeros.
  ///
  /// # Errors
  /// Same conditions as [`Kernel::value`] for this type.
  fn value_with_grad(
    &self,
    params: &[f64],
    x: &T,
    xprime: &T,
  ) -> Result<(f64, Vec<f64>), KernelError> {
    if params.len() != self.kernel.params_len() {
      return Err(KernelError::ParametersLengthMismatch);
    }
    let p = x.parts_len();
    if p != xprime.parts_len() {
      return Err(KernelError::InvalidArgument);
    }

    // d/dθ Σ_i k_i(θ) = Σ_i ∇k_i(θ): accumulate into one params-sized
    // vector instead of concatenating per-part gradients (which produced a
    // vector of length p * params_len, inconsistent with `params_len()`).
    let mut fx = 0.0;
    let mut gfx = vec![0.0; params.len()];
    for pi in 0..p {
      let (f, g) = self
        .kernel
        .value_with_grad(params, x.part(pi), xprime.part(pi))?;
      fx += f;
      for (acc, gi) in gfx.iter_mut().zip(g) {
        *acc += gi;
      }
    }

    Ok((fx, gfx))
  }
}