rv/process/gaussian/kernel/
constant_kernel.rs

1use super::{CovGrad, CovGradError, Kernel, KernelError};
2use nalgebra::base::constraint::{SameNumberOfColumns, ShapeConstraint};
3use nalgebra::base::storage::Storage;
4use nalgebra::{dvector, DMatrix, DVector, Dim, Matrix};
5use std::f64;
6
7#[cfg(feature = "serde1")]
8use serde::{Deserialize, Serialize};
9
/// Covariance kernel that returns the same value, `scale`, for every pair of
/// input points, independent of where those points are.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct ConstantKernel {
    /// The constant covariance value; the checked constructor `new` rejects
    /// values that are not strictly positive.
    scale: f64,
}
16
17impl ConstantKernel {
18    /// Create a new kernel with the given constant value
19    pub fn new(value: f64) -> Result<Self, KernelError> {
20        if value <= 0.0 {
21            Err(KernelError::ParameterOutOfBounds {
22                name: "value".to_string(),
23                given: value,
24                bounds: (0.0, f64::INFINITY),
25            })
26        } else {
27            Ok(Self { scale: value })
28        }
29    }
30
31    /// Create a new constant function kernel without checking the parameters
32    pub fn new_unchecked(scale: f64) -> Self {
33        Self { scale }
34    }
35}
36
37impl Default for ConstantKernel {
38    fn default() -> Self {
39        Self { scale: 1.0 }
40    }
41}
42
43impl std::convert::TryFrom<f64> for ConstantKernel {
44    type Error = KernelError;
45
46    fn try_from(value: f64) -> Result<Self, Self::Error> {
47        Self::new(value)
48    }
49}
50
51impl Kernel for ConstantKernel {
52    fn n_parameters(&self) -> usize {
53        1
54    }
55
56    fn covariance<R1, R2, C1, C2, S1, S2>(
57        &self,
58        x1: &Matrix<f64, R1, C1, S1>,
59        x2: &Matrix<f64, R2, C2, S2>,
60    ) -> DMatrix<f64>
61    where
62        R1: Dim,
63        R2: Dim,
64        C1: Dim,
65        C2: Dim,
66        S1: Storage<f64, R1, C1>,
67        S2: Storage<f64, R2, C2>,
68        ShapeConstraint: SameNumberOfColumns<C1, C2>,
69    {
70        DMatrix::from_element(x1.nrows(), x2.nrows(), self.scale)
71    }
72
73    fn is_stationary(&self) -> bool {
74        true
75    }
76
77    fn diag<R, C, S>(&self, x: &Matrix<f64, R, C, S>) -> DVector<f64>
78    where
79        R: Dim,
80        C: Dim,
81        S: Storage<f64, R, C>,
82    {
83        DVector::from_element(x.nrows(), self.scale)
84    }
85
86    fn parameters(&self) -> DVector<f64> {
87        dvector![self.scale.ln()]
88    }
89
90    fn reparameterize(&self, param_vec: &[f64]) -> Result<Self, KernelError> {
91        match param_vec {
92            [] => Err(KernelError::MissingParameters(1)),
93            [value] => Self::new(value.exp()),
94            _ => Err(KernelError::ExtraneousParameters(param_vec.len() - 1)),
95        }
96    }
97
98    fn covariance_with_gradient<R, C, S>(
99        &self,
100        x: &Matrix<f64, R, C, S>,
101    ) -> Result<(DMatrix<f64>, CovGrad), CovGradError>
102    where
103        R: Dim,
104        C: Dim,
105        S: Storage<f64, R, C>,
106    {
107        let cov = self.covariance(x, x);
108        let grad = CovGrad::new_unchecked(&[DMatrix::from_element(
109            x.nrows(),
110            x.nrows(),
111            self.scale,
112        )]);
113        Ok((cov, grad))
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constant_kernel() {
        let kernel = ConstantKernel::new(3.0).unwrap();
        assert::close(kernel.parameters()[0], 3.0_f64.ln(), 1E-10);
        assert!(kernel.parameters().relative_eq(
            &dvector![3.0_f64.ln()],
            1E-8,
            1E-8,
        ));

        let x = DMatrix::from_column_slice(2, 2, &[1.0, 3.0, 2.0, 4.0]);
        let y = DMatrix::from_column_slice(2, 2, &[5.0, 6.0, 7.0, 8.0]);

        let (cov, grad) = kernel.covariance_with_gradient(&x).unwrap();

        let expected_cov = DMatrix::from_row_slice(2, 2, &[3.0, 3.0, 3.0, 3.0]);

        let expected_grad =
            CovGrad::from_row_slices(2, 1, &[3.0, 3.0, 3.0, 3.0]).unwrap();

        assert!(cov.relative_eq(&expected_cov, 1E-8, 1E-8));
        assert!(grad.relative_eq(&expected_grad, 1E-8, 1E-8));

        let cov = kernel.covariance(&x, &y);
        let expected_cov = DMatrix::from_row_slice(2, 2, &[3.0, 3.0, 3.0, 3.0]);

        assert!(cov.relative_eq(&expected_cov, 1E-8, 1E-8));
    }

    /// Zero and negative scales are outside the accepted interval.
    #[test]
    fn new_rejects_non_positive_values() {
        assert!(matches!(
            ConstantKernel::new(0.0),
            Err(KernelError::ParameterOutOfBounds { .. })
        ));
        assert!(matches!(
            ConstantKernel::new(-1.5),
            Err(KernelError::ParameterOutOfBounds { .. })
        ));
    }

    /// `reparameterize(parameters())` reconstructs the same kernel.
    #[test]
    fn reparameterize_round_trips_log_scale() {
        let kernel = ConstantKernel::new(3.0).unwrap();
        let rebuilt = kernel
            .reparameterize(kernel.parameters().as_slice())
            .unwrap();
        assert::close(rebuilt.parameters()[0], 3.0_f64.ln(), 1E-10);
    }

    /// Wrong-length parameter slices are reported with the right counts.
    #[test]
    fn reparameterize_checks_parameter_count() {
        let kernel = ConstantKernel::default();
        assert!(matches!(
            kernel.reparameterize(&[]),
            Err(KernelError::MissingParameters(1))
        ));
        assert!(matches!(
            kernel.reparameterize(&[0.0, 1.0, 2.0]),
            Err(KernelError::ExtraneousParameters(2))
        ));
    }

    /// `diag(x)` must agree with the diagonal of `covariance(x, x)`.
    #[test]
    fn diag_matches_covariance_diagonal() {
        let kernel = ConstantKernel::new(2.5).unwrap();
        let x = DMatrix::from_column_slice(3, 1, &[1.0, 2.0, 3.0]);
        let diag = kernel.diag(&x);
        let expected = kernel.covariance(&x, &x).diagonal();
        assert!(diag.relative_eq(&expected, 1E-12, 1E-12));
    }
}
149}