use super::{CovGrad, CovGradError, E2METRIC, Kernel, KernelError, e2_norm};
use nalgebra::base::storage::Storage;
use nalgebra::{DMatrix, DVector, Dim, Matrix, dvector};
use nalgebra::{
Norm,
base::constraint::{SameNumberOfColumns, ShapeConstraint},
};
use std::f64;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct RationalQuadratic {
scale: f64,
mixture: f64,
}
impl RationalQuadratic {
pub fn new(scale: f64, mixture: f64) -> Result<Self, KernelError> {
if scale <= 0.0 {
Err(KernelError::ParameterOutOfBounds {
name: "scale".to_string(),
given: scale,
bounds: (0.0, f64::INFINITY),
})
} else if mixture <= 0.0 {
Err(KernelError::ParameterOutOfBounds {
name: "mixture".to_string(),
given: mixture,
bounds: (0.0, f64::INFINITY),
})
} else {
Ok(Self { scale, mixture })
}
}
#[must_use]
pub fn new_unchecked(scale: f64, mixture: f64) -> Self {
Self { scale, mixture }
}
}
impl Kernel for RationalQuadratic {
fn n_parameters(&self) -> usize {
2
}
fn covariance<R1, R2, C1, C2, S1, S2>(
&self,
x1: &Matrix<f64, R1, C1, S1>,
x2: &Matrix<f64, R2, C2, S2>,
) -> DMatrix<f64>
where
R1: Dim,
R2: Dim,
C1: Dim,
C2: Dim,
S1: Storage<f64, R1, C1>,
S2: Storage<f64, R2, C2>,
ShapeConstraint: SameNumberOfColumns<C1, C2>,
{
let d = (2.0 * self.scale * self.scale * self.mixture).sqrt();
DMatrix::from_fn(x1.nrows(), x2.nrows(), |i, j| {
let t = e2_norm(&x1.row(i), &x2.row(j), d);
(1.0 + t).powf(-self.mixture)
})
}
fn is_stationary(&self) -> bool {
true
}
fn diag<R, C, S>(&self, x: &Matrix<f64, R, C, S>) -> DVector<f64>
where
R: Dim,
C: Dim,
S: Storage<f64, R, C>,
{
DVector::repeat(x.len(), 1.0)
}
fn parameters(&self) -> DVector<f64> {
dvector![self.scale.ln(), self.mixture.ln()]
}
fn reparameterize(&self, params: &[f64]) -> Result<Self, KernelError> {
match params {
[] => Err(KernelError::MissingParameters(2)),
[_] => Err(KernelError::MissingParameters(1)),
[length_scale, mixture] => {
Self::new(length_scale.exp(), mixture.exp())
}
_ => Err(KernelError::ExtraneousParameters(params.len() - 1)),
}
}
fn covariance_with_gradient<R, C, S>(
&self,
x: &Matrix<f64, R, C, S>,
) -> Result<(DMatrix<f64>, CovGrad), CovGradError>
where
R: Dim,
C: Dim,
S: Storage<f64, R, C>,
{
let n = x.nrows();
let mut cov = DMatrix::zeros(n, n);
let mut grad = CovGrad::zeros(n, 2);
let d = 2.0 * self.mixture * self.scale.powi(2);
for i in 0..n {
for j in 0..i {
let d2 = E2METRIC.metric_distance(&x.row(i), &x.row(j));
let temp = d2 / d;
let base = 1.0 + temp;
let k = base.powf(-self.mixture);
cov[(i, j)] = k;
cov[(j, i)] = k;
let dk_dl = d2 * k / (self.scale.powi(2) * base);
let dk_da = k * base.ln().mul_add(
-self.mixture,
d2 / (2.0 * self.scale.powi(2) * base),
);
grad[(i, j, 0)] = dk_dl;
grad[(j, i, 0)] = dk_dl;
grad[(j, i, 1)] = dk_da;
grad[(i, j, 1)] = dk_da;
}
cov[(i, i)] = 1.0;
}
Ok((cov, grad))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rational_quadratic() -> Result<(), KernelError> {
let kernel = RationalQuadratic::new(3.0, 5.0)?;
assert::close(kernel.parameters()[0], 3.0_f64.ln(), 1E-10);
assert::close(kernel.parameters()[1], 5.0_f64.ln(), 1E-10);
assert!(kernel.parameters().relative_eq(
&dvector![3.0_f64.ln(), 5.0_f64.ln()],
1E-10,
1E-10
));
let x = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
let y = DMatrix::from_row_slice(2, 2, &[5.0, 7.0, 6.0, 8.0]);
let cov = kernel.covariance(&x, &y);
let expected_cov = DMatrix::from_row_slice(
2,
2,
&[
5_904_900_000.0 / 38_579_489_651.0,
5_904_900_000.0 / 78_502_725_751.0,
5_904_900_000.0 / 11_592_740_742.0,
1_889_568.0 / 6_436_343.0,
],
);
assert!(cov.relative_eq(&expected_cov, 1E-8, 1E-8));
let (cov, grad) = kernel.covariance_with_gradient(&x)?;
let expected_cov = DMatrix::from_row_slice(
2,
2,
&[
1.0,
184_528_125.0 / 282_475_249.0,
184_528_125.0 / 282_475_249.0,
1.0,
],
);
let eg_l = 0.533_268_68;
let eg_a = -0.011_514_11;
let expected_grad = CovGrad::from_row_slices(
2,
2,
&[0.0, eg_l, eg_l, 0.0, 0.0, eg_a, eg_a, 0.0],
)
.unwrap();
assert!(cov.relative_eq(&expected_cov, 1E-8, 1E-8));
assert!(grad.relative_eq(&expected_grad, 1E-8, 1E-8));
Ok(())
}
}