use crate::linalg::{Dot, Matrix, Vector};
/// A covariance (kernel) function evaluated on a pair of inputs.
///
/// `I` is the input type and `O` the result type. In this file the scalar
/// impls (`f64` / `&f64`) yield an `f64`, while the `Vector` / `Matrix`
/// impls yield a `Matrix` of pairwise values.
pub trait Kernel<I, O> {
    /// Evaluates the kernel between `x` and `y`.
    fn forward(&self, x: I, y: I) -> O;
}
/// Radial basis function (squared-exponential) covariance kernel.
///
/// The `Kernel` impls in this file evaluate it as
/// `var * exp(-(x - y)^2 / (2 * length_scale^2))`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RBFKernel {
    /// Output (signal) variance; multiplies every kernel value.
    var: f64,
    /// Length scale dividing the squared input distance.
    length_scale: f64,
}

/// Alternative name for [`RBFKernel`].
pub type SquaredExponentialKernel = RBFKernel;

impl RBFKernel {
    /// Creates an RBF kernel with output variance `var` and length scale `length_scale`.
    ///
    /// # Panics
    ///
    /// Panics if `var` or `length_scale` is not strictly positive; NaN is
    /// rejected too, because `NaN > 0.` is false.
    pub fn new(var: f64, length_scale: f64) -> Self {
        assert!(var > 0., "output variance must be positive");
        assert!(length_scale > 0., "length scale must be positive");
        Self { var, length_scale }
    }
}
/// Rational-quadratic covariance kernel, parameterised by an output variance,
/// a scale-mixture parameter `alpha` and a length scale.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RationalQuadraticKernel {
    /// Output (signal) variance; multiplies every kernel value.
    var: f64,
    /// Scale-mixture parameter.
    alpha: f64,
    /// Length scale applied to the squared input distance.
    length_scale: f64,
}

/// Short alias for [`RationalQuadraticKernel`].
pub type RQKernel = RationalQuadraticKernel;

impl RationalQuadraticKernel {
    /// Creates a rational-quadratic kernel from its three hyperparameters.
    ///
    /// # Panics
    ///
    /// Panics if any of `var`, `alpha` or `length_scale` is not strictly
    /// positive; NaN is rejected too, because `NaN > 0.` is false.
    pub fn new(var: f64, alpha: f64, length_scale: f64) -> Self {
        assert!(var > 0., "output variance must be positive");
        assert!(alpha > 0., "scale mixture parameter must be positive");
        assert!(length_scale > 0., "length scale must be positive");
        Self {
            var,
            alpha,
            length_scale,
        }
    }
}
// Generates `Kernel<$scalar, f64>` for `RBFKernel` for each listed scalar
// type, evaluating `var * exp(-(x - y)^2 / (2 * length_scale^2))`.
// Accepts one or more comma-separated types.
macro_rules! impl_kernel_f64_for_rbf {
    ($($scalar: ty),+ $(,)?) => {
        $(
            impl Kernel<$scalar, f64> for RBFKernel {
                fn forward(&self, x: $scalar, y: $scalar) -> f64 {
                    // Works for both `f64` and `&f64`: the difference is an owned `f64`.
                    let diff = x - y;
                    let two_l_sq = 2. * self.length_scale.powi(2);
                    self.var * (-diff.powi(2) / two_l_sq).exp()
                }
            }
        )+
    };
}
impl_kernel_f64_for_rbf!(f64, &f64);
// Generates `Kernel<$t1, f64>` for `RationalQuadraticKernel`, evaluating
//     k(x, y) = var * (1 + (x - y)^2 / (2 * alpha * length_scale^2))^(-alpha).
// The exponent is negative, so the value is `var` at x == y and decays towards
// 0 as |x - y| grows. Raising to `+alpha` would make it grow with distance.
macro_rules! impl_kernel_f64_for_rq {
    ($t1: ty) => {
        impl Kernel<$t1, f64> for RationalQuadraticKernel {
            fn forward(&self, x: $t1, y: $t1) -> f64 {
                let base = 1. + (x - y).powi(2) / (2. * self.alpha * self.length_scale.powi(2));
                base.powf(-self.alpha) * self.var
            }
        }
    };
}
impl_kernel_f64_for_rq!(f64);
impl_kernel_f64_for_rq!(&f64);
// Generates `Kernel<$input, $output>` for `RBFKernel` over vector/matrix inputs.
// Both inputs are passed through `reshape(-1, 1)` first, so every element is
// treated as its own scalar point.
// NOTE(review): this assumes numpy-like `reshape` and (n, 1) + (1, m)
// broadcasting in `crate::linalg`, and that `dot_t` computes x * y^T — confirm.
macro_rules! impl_kernel_vec_for_rbf {
    ($input: ty, $output: ty) => {
        impl Kernel<$input, $output> for RBFKernel {
            fn forward(&self, x: $input, y: $input) -> $output {
                let col_x = x.reshape(-1, 1);
                let col_y = y.reshape(-1, 1);
                // Squared distances expanded as x^2 + y^2 - 2 x y^T, which avoids
                // an explicit pairwise loop.
                let sq_x = col_x.powi(2).reshape(-1, 1);
                let sq_y = col_y.powi(2).reshape(1, -1);
                let cross = 2. * col_x.dot_t(col_y);
                let sq_dist = sq_x + sq_y - cross;
                let two_l_sq = 2. * self.length_scale.powi(2);
                (-sq_dist / two_l_sq).exp() * self.var
            }
        }
    };
}
impl_kernel_vec_for_rbf!(Matrix, Matrix);
impl_kernel_vec_for_rbf!(Vector, Matrix);
impl_kernel_vec_for_rbf!(&Matrix, Matrix);
impl_kernel_vec_for_rbf!(&Vector, Matrix);
// Generates `Kernel<$t1, $t2>` for `RationalQuadraticKernel` over vector/matrix
// inputs, producing elementwise
//     var * (1 + d^2 / (2 * alpha * length_scale^2))^(-alpha)
// where d^2 is the matrix of squared differences between the points of x and y.
// The exponent must be negative so the covariance decays with distance.
// Both inputs are passed through `reshape(-1, 1)` first, so every element is
// treated as its own scalar point.
// NOTE(review): this assumes numpy-like `reshape` and (n, 1) + (1, m)
// broadcasting in `crate::linalg`, and that `dot_t` computes x * y^T — confirm.
macro_rules! impl_kernel_vec_for_rq {
    ($t1: ty, $t2: ty) => {
        impl Kernel<$t1, $t2> for RationalQuadraticKernel {
            fn forward(&self, x: $t1, y: $t1) -> $t2 {
                let (x, y) = (x.reshape(-1, 1), y.reshape(-1, 1));
                // Squared distances expanded as x^2 + y^2 - 2 x y^T.
                let sq_dist =
                    x.powi(2).reshape(-1, 1) + y.powi(2).reshape(1, -1) - 2. * x.dot_t(y);
                let base = 1. + sq_dist / (2. * self.alpha * self.length_scale.powi(2));
                base.powf(-self.alpha) * self.var
            }
        }
    };
}
impl_kernel_vec_for_rq!(Matrix, Matrix);
impl_kernel_vec_for_rq!(Vector, Matrix);
impl_kernel_vec_for_rq!(&Matrix, Matrix);
impl_kernel_vec_for_rq!(&Vector, Matrix);