// rv/process/gaussian/mod.rs

1//! Gaussian Processes
2
3use argmin::solver::{linesearch::MoreThuenteLineSearch, quasinewton::LBFGS};
4use nalgebra::linalg::Cholesky;
5use nalgebra::{DMatrix, DVector, Dyn};
6use rand::Rng;
7#[cfg(feature = "serde1")]
8use serde::{Deserialize, Serialize};
9use std::cell::OnceCell;
10
11use crate::consts::HALF_LN_2PI;
12use crate::dist::MvGaussian;
13use crate::traits::*;
14
15pub mod kernel;
16use kernel::{Kernel, KernelError};
17
18mod noise_model;
19pub use self::noise_model::NoiseModel;
20
21use super::{RandomProcess, RandomProcessMle};
22
23#[inline]
24fn outer_product_self(col: &DVector<f64>) -> DMatrix<f64> {
25    let row = DMatrix::from_row_slice(1, col.nrows(), col.as_slice());
26    col * row
27}
28
/// Errors from GaussianProcess
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum GaussianProcessError {
    /// The kernel is not returning a positive-definite matrix. Try adding a small, constant noise parameter as y_train_sigma.
    NotPositiveSemiDefinite,
    /// Error from the kernel function
    KernelError(KernelError),
    /// The given noise model does not match the training data
    /// (carries the mismatch description produced by the noise model)
    MisshapenNoiseModel(String),
}
41
42impl std::error::Error for GaussianProcessError {}
43impl std::fmt::Display for GaussianProcessError {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        match self {
46            Self::NotPositiveSemiDefinite => {
47                writeln!(f, "Covariance matrix is not semi-positive definite")
48            }
49            Self::MisshapenNoiseModel(msg) => {
50                writeln!(f, "Noise model error: {}", msg)
51            }
52            Self::KernelError(e) => writeln!(f, "Error from kernel: {}", e),
53        }
54    }
55}
56
/// Allow `?` to lift kernel errors into `GaussianProcessError`.
impl From<KernelError> for GaussianProcessError {
    fn from(e: KernelError) -> Self {
        Self::KernelError(e)
    }
}
62
/// A Gaussian Process fit to training data.
///
/// Caches the Cholesky factorization of the (noisy) training covariance, its
/// inverse, and the dual coefficients, so predictions and marginal-likelihood
/// evaluations do not re-factorize.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct GaussianProcess<K>
where
    K: Kernel,
{
    /// Cholesky Decomposition of K
    k_chol: Cholesky<f64, Dyn>,
    /// Dual coefficients of training data in kernel space.
    alpha: DVector<f64>,
    /// Covariance Kernel
    pub kernel: K,
    /// x values used in training
    x_train: DMatrix<f64>,
    /// y values used in training
    y_train: DVector<f64>,
    /// Inverse covariance matrix
    k_inv: DMatrix<f64>,
    /// Noise Model
    pub noise_model: NoiseModel,
}
85
86impl<K> GaussianProcess<K>
87where
88    K: Kernel,
89{
90    /// Train a Gaussian Process on the given data points
91    ///
92    /// # Arguments
93    /// * `kernel` - Kernel to use to determine covariance
94    /// * `x_train` - Values to use for input into `f`
95    /// * `y_train` - Known values for `f(x)`
96    /// * `noise_model` - Noise model to use for fitting
97    pub fn train(
98        kernel: K,
99        x_train: DMatrix<f64>,
100        y_train: DVector<f64>,
101        noise_model: NoiseModel,
102    ) -> Result<Self, GaussianProcessError> {
103        let k = noise_model
104            .add_noise_to_kernel(&kernel.covariance(&x_train, &x_train))
105            .map_err(GaussianProcessError::MisshapenNoiseModel)?;
106
107        // Decompose K into Cholesky lower lower triangular matrix
108        let k_chol = match Cholesky::new(k) {
109            Some(ch) => Ok(ch),
110            None => Err(GaussianProcessError::NotPositiveSemiDefinite),
111        }?;
112
113        let k_inv = k_chol.inverse();
114        let alpha = k_chol.solve(&y_train);
115
116        Ok(GaussianProcess {
117            k_chol,
118            alpha,
119            kernel,
120            x_train,
121            y_train,
122            k_inv,
123            noise_model,
124        })
125    }
126
127    /// Return the inverse of K.
128    pub fn k_inv(&self) -> &DMatrix<f64> {
129        &self.k_inv
130    }
131
132    /// Return the Cholesky decomposition of K
133    pub fn k_chol(&self) -> &Cholesky<f64, Dyn> {
134        &(self.k_chol)
135    }
136
137    /// Return the kernel being used in this GP
138    pub fn kernel(&self) -> &K {
139        &(self.kernel)
140    }
141}
142
143impl<K> RandomProcess<f64> for GaussianProcess<K>
144where
145    K: Kernel,
146{
147    type Index = DVector<f64>;
148    type SampleFunction = GaussianProcessPrediction<K>;
149    type Error = GaussianProcessError;
150
151    fn sample_function(&self, indices: &[Self::Index]) -> Self::SampleFunction {
152        let n = indices.len();
153        let m = indices.first().map(|i| i.len()).unwrap_or(0);
154
155        let indices: DMatrix<f64> = DMatrix::from_iterator(
156            n,
157            m,
158            indices.iter().flat_map(|i| i.iter().cloned()),
159        );
160        let k_trans = self.kernel.covariance(&indices, &self.x_train);
161        let y_mean = &k_trans * &self.alpha;
162        GaussianProcessPrediction {
163            gp: self.clone(),
164            y_mean,
165            k_trans,
166            xs: indices,
167            cov: OnceCell::new(),
168            dist: OnceCell::new(),
169        }
170    }
171
172    fn ln_m(&self) -> f64 {
173        let k_chol = self.k_chol();
174        let dlog_sum = k_chol.l_dirty().diagonal().map(|x| x.ln()).sum();
175        let n: f64 = self.x_train.nrows() as f64;
176        let alpha = k_chol.solve(&self.y_train);
177        n.mul_add(
178            -HALF_LN_2PI,
179            (-0.5_f64).mul_add(self.y_train.dot(&alpha), -dlog_sum),
180        )
181    }
182
183    fn ln_m_with_params(
184        &self,
185        parameter: &DVector<f64>,
186    ) -> Result<(f64, DVector<f64>), GaussianProcessError> {
187        let kernel = self
188            .kernel
189            .reparameterize(&parameter.iter().copied().collect::<Vec<f64>>())
190            .map_err(GaussianProcessError::KernelError)?;
191
192        // GPML Equation 2.30
193        let (k, k_grad) = kernel
194            .covariance_with_gradient(&self.x_train)
195            .map_err(|e| GaussianProcessError::KernelError(e.into()))?;
196        let k = self.noise_model.add_noise_to_kernel(&k).unwrap(); // if we got here, the noise model will be okay
197
198        let m = k.nrows();
199        // TODO: try to symmetricize the matrix
200        let maybe_k_chol = Cholesky::new(k.clone());
201
202        if maybe_k_chol.is_none() {
203            eprintln!(
204                "failed to find chol of k = {}, with parameters = {:?}",
205                k, parameter
206            );
207        }
208
209        let k_chol = maybe_k_chol
210            .ok_or(GaussianProcessError::NotPositiveSemiDefinite)?;
211        let alpha = k_chol.solve(&self.y_train);
212        let dlog_sum = k_chol.l_dirty().diagonal().map(|x| x.ln()).sum();
213        let n: f64 = self.x_train.nrows() as f64;
214
215        let ln_m = n.mul_add(
216            -HALF_LN_2PI,
217            (-0.5_f64).mul_add(self.y_train.dot(&alpha), -dlog_sum),
218        );
219
220        // GPML Equation 5.9
221        let aat_kinv = &outer_product_self(&alpha) - &k_chol.inverse();
222        let grad_ln_m: Vec<f64> = (0..parameter.len())
223            .map(|i| {
224                let theta_i_grad = &k_grad[i];
225                let mut sum = 0.0;
226                for j in 0..m {
227                    sum += (aat_kinv.row(j) * theta_i_grad.column(j))[0];
228                }
229                0.5 * sum
230            })
231            .collect();
232        let grad_ln_m = DVector::from(grad_ln_m);
233
234        Ok((ln_m, grad_ln_m))
235    }
236
237    fn parameters(&self) -> DVector<f64> {
238        let kernel = self.kernel();
239        kernel.parameters()
240    }
241
242    fn set_parameters(
243        self,
244        parameters: &DVector<f64>,
245    ) -> Result<Self, GaussianProcessError> {
246        let (kernel, leftovers) = self
247            .kernel
248            .consume_parameters(parameters.iter().copied())
249            .map_err(GaussianProcessError::KernelError)?;
250        let leftovers: Vec<f64> = leftovers.collect();
251        if !leftovers.is_empty() {
252            return Err(GaussianProcessError::KernelError(
253                KernelError::ExtraneousParameters(leftovers.len()),
254            ));
255        }
256
257        Self::train(kernel, self.x_train, self.y_train, self.noise_model)
258    }
259}
260
impl<K> RandomProcessMle<f64> for GaussianProcess<K>
where
    K: Kernel,
{
    /// L-BFGS with More-Thuente line search for maximizing the log
    /// marginal likelihood over kernel parameters.
    type Solver = LBFGS<
        MoreThuenteLineSearch<DVector<f64>, DVector<f64>, f64>,
        DVector<f64>,
        DVector<f64>,
        f64,
    >;

    /// Construct the default optimizer: L-BFGS with memory size 10.
    fn generate_solver() -> Self::Solver {
        let linesearch = MoreThuenteLineSearch::new();
        LBFGS::new(linesearch, 10)
    }

    /// Draw a random parameter vector with one value in [-5, 5) per
    /// kernel parameter (used for optimizer restarts).
    fn random_params<R: Rng>(&self, rng: &mut R) -> DVector<f64> {
        let n = self.parameters().len();
        DVector::from_iterator(n, (0..n).map(|_| rng.gen_range(-5.0..5.0)))
    }
}
282
/// Structure for making GP predictions
pub struct GaussianProcessPrediction<K>
where
    K: Kernel,
{
    /// Parent GP
    gp: GaussianProcess<K>,
    /// Mean of y values
    y_mean: DVector<f64>,
    /// Intermediate matrix (K(xs, x_train), reused by cov and variance)
    k_trans: DMatrix<f64>,
    /// Values to predict `f(x)` against.
    xs: DMatrix<f64>,
    /// Covariance matrix (computed lazily on first use)
    cov: OnceCell<DMatrix<f64>>,
    /// Output Distribution (computed lazily on first use)
    dist: OnceCell<MvGaussian>,
}
301
302impl<K> GaussianProcessPrediction<K>
303where
304    K: Kernel,
305{
306    /// Return the covariance of the posterior
307    pub fn cov(&self) -> &DMatrix<f64> {
308        self.cov.get_or_init(|| {
309            let v = self.gp.k_chol().solve(&(self.k_trans.transpose()));
310            let kernel = self.gp.kernel();
311            &kernel.covariance(&self.xs, &self.xs) - &(self.k_trans) * &v
312        })
313    }
314
315    /// Return the standard deviation of posterior.
316    pub fn std(&self) -> DVector<f64> {
317        let kernel = self.gp.kernel();
318        let k_inv = self.gp.k_inv();
319        let k_ti = &(self.k_trans) * k_inv;
320
321        let mut y_var: DVector<f64> = kernel.diag(&self.xs);
322        for i in 0..y_var.nrows() {
323            y_var[i] -= (0..k_inv.ncols())
324                .map(|j| k_ti[(i, j)] * self.k_trans[(i, j)])
325                .sum::<f64>();
326        }
327        y_var.map(|e| e.sqrt())
328    }
329
330    /// Return the MV Gaussian distribution which shows the predicted values
331    pub fn dist(&self) -> &MvGaussian {
332        let mean = self.y_mean.clone();
333        let cov = (self.cov()).clone();
334        self.dist
335            .get_or_init(|| MvGaussian::new_unchecked(mean, cov))
336    }
337
338    /// Draw a single value from the corresponding MV Gaussian
339    pub fn draw<RNG: Rng>(&self, rng: &mut RNG) -> DVector<f64> {
340        self.dist().draw(rng)
341    }
342
343    /// Return a number of samples from the MV Gaussian
344    pub fn sample<R: Rng>(
345        &self,
346        size: usize,
347        rng: &mut R,
348    ) -> Vec<DVector<f64>> {
349        self.dist().sample(size, rng)
350    }
351}
352
impl<K> HasDensity<DVector<f64>> for GaussianProcessPrediction<K>
where
    K: Kernel,
{
    /// Log density of `x` under the posterior predictive MV Gaussian.
    fn ln_f(&self, x: &DVector<f64>) -> f64 {
        self.dist().ln_f(x)
    }
}
361
impl<K> Sampleable<DVector<f64>> for GaussianProcessPrediction<K>
where
    K: Kernel,
{
    /// Draw from the posterior predictive MV Gaussian.
    fn draw<R: Rng>(&self, rng: &mut R) -> DVector<f64> {
        self.dist().draw(rng)
    }
}
370
impl<K> Mean<DVector<f64>> for GaussianProcessPrediction<K>
where
    K: Kernel,
{
    /// Posterior predictive mean (precomputed at prediction time).
    fn mean(&self) -> Option<DVector<f64>> {
        Some(self.y_mean.clone())
    }
}
379
380impl<K> Variance<DVector<f64>> for GaussianProcessPrediction<K>
381where
382    K: Kernel,
383{
384    fn variance(&self) -> Option<DVector<f64>> {
385        let kernel = self.gp.kernel();
386        let k_inv = self.gp.k_inv();
387        let k_ti = &(self.k_trans) * k_inv;
388
389        let mut y_var: DVector<f64> = kernel.diag(&self.xs);
390        for i in 0..y_var.nrows() {
391            y_var[i] -= (0..k_inv.ncols())
392                .map(|j| k_ti[(i, j)] * self.k_trans[(i, j)])
393                .sum::<f64>();
394        }
395        Some(y_var)
396    }
397}
398
399#[cfg(test)]
400mod tests {
401    use self::kernel::{ConstantKernel, ProductKernel, RBFKernel};
402    use super::*;
403    use crate::test::relative_eq;
404    use nalgebra::dvector;
405    use rand::SeedableRng;
406    use rand_xoshiro::Xoshiro256Plus;
407
408    fn arange(start: f64, stop: f64, step_size: f64) -> Vec<f64> {
409        let size = ((stop - start) / step_size).floor() as usize;
410        (0..size)
411            .map(|i| (i as f64).mul_add(step_size, start))
412            .collect()
413    }
414
    // Fit a default RBF GP to sin(x) on five points and check the posterior
    // mean and covariance on a 1-D grid against fixed reference values.
    // NOTE(review): reference values look machine-generated (presumably from
    // scikit-learn's GaussianProcessRegressor) — confirm provenance.
    #[test]
    fn simple() {
        let x_train: DMatrix<f64> =
            DMatrix::from_column_slice(5, 1, &[-4.0, -3.0, -2.0, -1.0, 1.0]);
        let y_train: DVector<f64> = x_train.map(|x| x.sin()).column(0).into();

        let kernel = RBFKernel::default();
        let gp = GaussianProcess::train(
            kernel,
            x_train,
            y_train,
            NoiseModel::default(),
        )
        .unwrap();

        let xs: Vec<DVector<f64>> = arange(-5.0, 5.0, 1.0)
            .into_iter()
            .map(|x| dvector![x])
            .collect();
        let pred = gp.sample_function(xs.as_slice());

        let expected_mean: DMatrix<f64> = DMatrix::from_column_slice(
            10,
            1,
            &[
                0.614_097_52,
                0.756_802_5,
                -0.141_120_01,
                -0.909_297_43,
                -0.841_470_98,
                0.085_333_65,
                0.841_470_98,
                0.563_985_6,
                0.127_422_02,
                0.010_476_83,
            ],
        );

        let mean = pred.mean().expect("Should be able to compute the mean");
        assert!(mean.relative_eq(&expected_mean, 1E-8, 1E-8));

        let expected_cov = DMatrix::from_row_slice(
            10,
            10,
            &[
                5.096_256_32e-01,
                0.000_000_00e+00,
                5.551_115_12e-17,
                6.765_421_56e-17,
                3.165_870_34e-17,
                3.449_672_76e-02,
                3.520_513_77e-19,
                -7.750_552_24e-03,
                -2.002_925_07e-03,
                -1.676_185_74e-04,
                -1.110_223_02e-16,
                9.999_999_72e-09,
                1.110_223_02e-16,
                1.387_778_78e-16,
                6.938_893_90e-17,
                1.707_618_42e-17,
                -6.920_259_18e-19,
                -2.072_911_31e-18,
                -5.059_828_46e-19,
                -4.199_226_50e-20,
                -1.110_223_02e-16,
                -1.110_223_02e-16,
                9.999_999_94e-09,
                0.000_000_00e+00,
                -5.551_115_12e-17,
                -5.030_698_08e-17,
                -1.057_097_12e-17,
                7.377_656_97e-19,
                3.917_957_51e-19,
                3.472_752_04e-20,
                -6.765_421_56e-17,
                -2.775_557_56e-17,
                3.330_669_07e-16,
                1.000_000_04e-08,
                1.110_223_02e-16,
                0.000_000_00e+00,
                1.561_251_13e-17,
                1.713_039_43e-17,
                4.150_037_93e-18,
                3.445_372_69e-19,
                -1.317_305_64e-17,
                1.040_834_09e-17,
                0.000_000_00e+00,
                -1.110_223_02e-16,
                9.999_999_94e-09,
                0.000_000_00e+00,
                -2.775_557_56e-17,
                -2.081_668_17e-17,
                -4.987_329_99e-18,
                -4.154_696_61e-19,
                3.449_672_76e-02,
                7.676_151_38e-17,
                7.806_255_64e-17,
                0.000_000_00e+00,
                0.000_000_00e+00,
                2.663_127_02e-01,
                0.000_000_00e+00,
                -1.775_970_42e-01,
                -5.699_341_56e-02,
                -5.235_330_37e-03,
                -2.629_524_45e-18,
                -1.969_351_60e-18,
                -3.415_236_84e-18,
                -3.469_446_95e-18,
                0.000_000_00e+00,
                0.000_000_00e+00,
                9.999_999_94e-09,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                -7.750_552_24e-03,
                -9.403_293_25e-18,
                -3.769_720_13e-18,
                9.378_348_79e-18,
                1.387_778_78e-17,
                -1.775_970_42e-01,
                1.110_223_02e-16,
                6.235_919_81e-01,
                5.222_724_53e-01,
                1.284_158_94e-01,
                -2.002_925_07e-03,
                -2.538_164_32e-18,
                -2.373_717_19e-18,
                4.438_452_64e-18,
                -2.439_454_89e-18,
                -5.699_341_56e-02,
                0.000_000_00e+00,
                5.222_724_53e-01,
                9.811_305_76e-01,
                6.049_809_83e-01,
                -1.676_185_74e-04,
                -3.012_464_45e-19,
                -2.266_319_91e-19,
                -6.048_013_77e-20,
                -2.020_173_58e-19,
                -5.235_330_37e-03,
                0.000_000_00e+00,
                1.284_158_94e-01,
                6.049_809_83e-01,
                9.998_727_40e-01,
            ],
        );

        let cov = pred.cov();
        assert!(cov.relative_eq(&expected_cov, 1E-7, 1E-7))
    }
566
    // Log marginal likelihood and its gradient at the default (untuned)
    // parameters of an RBF × Constant kernel.
    #[test]
    fn log_marginal_a() {
        let x_train: DMatrix<f64> =
            DMatrix::from_column_slice(5, 1, &[-4.0, -3.0, -2.0, -1.0, 1.0]);
        let y_train: DVector<f64> = x_train.map(|x| x.sin()).column(0).into();

        let kernel = RBFKernel::default() * ConstantKernel::default();
        let parameters = kernel.parameters();
        assert!(&parameters.relative_eq(&dvector![0.0, 0.0], 1E-9, 1E-9));

        let expected_ln_m = -5.029_140_040_847_684;
        let expected_grad = dvector![2.068_285_41, -1.191_110_32];

        let gp = GaussianProcess::train(
            kernel,
            x_train,
            y_train,
            NoiseModel::default(),
        )
        .unwrap();
        // Without Gradient
        assert::close(gp.ln_m(), expected_ln_m, 1E-7);

        // With Gradient
        let (ln_m, grad_ln_m) = gp.ln_m_with_params(&parameters).unwrap();
        assert::close(ln_m, expected_ln_m, 1E-7);
        assert!(grad_ln_m.relative_eq(&expected_grad, 1E-7, 1E-7));
    }
595
    // Log marginal likelihood at the (already optimal) parameters: the
    // gradient should vanish and `ln_m` should agree with `ln_m_with_params`.
    #[test]
    fn log_marginal_b() -> Result<(), KernelError> {
        let x_train: DMatrix<f64> =
            DMatrix::from_column_slice(5, 1, &[-4.0, -3.0, -2.0, -1.0, 1.0]);
        let y_train: DVector<f64> = x_train.map(|x| x.sin()).column(0).into();

        let kernel = RBFKernel::new(1.994_891_474_270_000_8)?
            * ConstantKernel::new(1.221_163_421_070_665)?;
        let parameters = kernel.parameters();
        assert!(relative_eq(
            &parameters,
            &dvector![0.690_589_65, 0.199_804_03],
            1E-7,
            1E-7
        ));

        let expected_ln_m = -3.414_870_095_916_796;
        let expected_grad = dvector![0.0, 0.0];

        let gp = GaussianProcess::train(
            kernel,
            x_train,
            y_train,
            NoiseModel::default(),
        )
        .unwrap();
        // Without Gradient
        let ln_m = gp.ln_m();
        assert::close(ln_m, expected_ln_m, 1E-7);

        // With Gradient
        let (ln_m, grad_ln_m) = gp.ln_m_with_params(&parameters).unwrap();
        assert::close(ln_m, expected_ln_m, 1E-7);
        assert!(grad_ln_m.relative_eq(&expected_grad, 1E-6, 1E-6));
        Ok(())
    }
632
    // MLE optimization with a single kernel parameter (RBF length scale)
    // converges to the known optimum from random restarts.
    #[test]
    fn optimize_gp_1_param() {
        let x_train: DMatrix<f64> =
            DMatrix::from_column_slice(5, 1, &[-4.0, -3.0, -2.0, -1.0, 1.0]);
        let y_train: DVector<f64> = x_train.map(|x| x.sin()).column(0).into();

        let kernel = RBFKernel::default();
        let noise_model = NoiseModel::default();

        let gp = GaussianProcess::train(kernel, x_train, y_train, noise_model)
            .unwrap();

        let mut rng = Xoshiro256Plus::seed_from_u64(0xABCD);
        let gp = gp.optimize(100, 10, &mut rng).expect("Failed to optimize");
        let opt_params = gp.kernel().parameters();

        assert!(opt_params.relative_eq(&dvector![0.657_854_21], 1E-5, 1E-5));
        assert::close(gp.ln_m(), -3.444_937_833_462_115, 1E-7);
        // ln_m and ln_m_with_params must agree at the optimum.
        assert::close(
            gp.ln_m(),
            gp.ln_m_with_params(&gp.kernel().parameters()).unwrap().0,
            1E-7,
        );
    }
657
    // MLE optimization over two kernel parameters (Constant × RBF) converges
    // to the same optimum as `log_marginal_b`.
    #[test]
    fn optimize_gp_2_param() {
        let x_train: DMatrix<f64> =
            DMatrix::from_column_slice(5, 1, &[-4.0, -3.0, -2.0, -1.0, 1.0]);
        let y_train: DVector<f64> = x_train.map(|x| x.sin()).column(0).into();

        let kernel = ConstantKernel::default() * RBFKernel::default();
        let noise_model = NoiseModel::default();

        let gp = GaussianProcess::train(kernel, x_train, y_train, noise_model)
            .unwrap();

        let mut rng = Xoshiro256Plus::seed_from_u64(0xABCD);
        let gp = gp.optimize(200, 30, &mut rng).expect("Failed to optimize");
        let opt_params = gp.kernel().parameters();

        assert!(opt_params.relative_eq(
            &dvector![0.199_804_03, 0.690_589_65],
            1E-5,
            1E-5
        ));

        assert::close(gp.ln_m(), -3.414_870_095_916_796, 1E-7);
        // ln_m and ln_m_with_params must agree at the optimum.
        assert::close(
            gp.ln_m(),
            gp.ln_m_with_params(&gp.kernel().parameters()).unwrap().0,
            1E-7,
        );
    }
687
    // Cholesky factor of the training covariance with zero observation
    // noise, checked against fixed reference values.
    #[test]
    fn no_noise_k_chol() -> Result<(), KernelError> {
        let xs: DMatrix<f64> =
            DMatrix::from_column_slice(6, 1, &[1., 3., 5., 6., 7., 8.]);
        let ys: DVector<f64> = xs.map(|x| x * x.sin()).column(0).into();

        let kernel: ProductKernel<ConstantKernel, RBFKernel> =
            (ConstantKernel::new_unchecked(1.0)
                * RBFKernel::new_unchecked(1.0))
            .reparameterize(&[3.099_752_67, 0.516_338_23])?;
        let gp =
            GaussianProcess::train(kernel, xs, ys, NoiseModel::Uniform(0.0))
                .expect("Should produce GP");
        let expected_k_chol: DMatrix<f64> = DMatrix::from_row_slice(
            6,
            6,
            &[
                4.710_887_58e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                2.311_209_28e+00,
                4.104_969_36e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                2.729_284_89e-01,
                2.498_691_55e+00,
                3.984_283_17e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                5.498_016_88e-02,
                1.058_107_06e+00,
                3.994_303_01e+00,
                2.261_723_20e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                7.757_674_14e-03,
                3.088_465_97e-01,
                2.538_478_56e+00,
                3.584_280_88e+00,
                1.675_133_57e+00,
                0.000_000_00e+00,
                7.666_996_49e-04,
                6.266_390_03e-02,
                1.082_699_33e+00,
                2.872_531_28e+00,
                3.289_048_54e+00,
                1.395_356_72e+00,
            ],
        );

        assert!(gp.k_chol().l().relative_eq(&expected_k_chol, 1E-8, 1E-8));
        Ok(())
    }
747
    // Cholesky factor of the training covariance with per-point observation
    // noise (variance dy²), checked against fixed reference values.
    #[test]
    fn noisy_k_chol() -> Result<(), KernelError> {
        let xs: DMatrix<f64> =
            DMatrix::from_column_slice(6, 1, &[1., 3., 5., 6., 7., 8.]);
        let ys: DVector<f64> = xs.map(|x| x * x.sin()).column(0).into();
        let dy = DVector::from_row_slice(&[
            0.917_022,
            1.220_324_49,
            0.500_114_37,
            0.802_332_57,
            0.646_755_89,
            0.592_338_59,
        ]);

        let ys = &ys + &dy;

        let kernel: ProductKernel<ConstantKernel, RBFKernel> =
            (ConstantKernel::new_unchecked(1.0)
                * RBFKernel::new_unchecked(1.0))
            .reparameterize(&[2.886_720_93, -0.033_327_73])?;
        let gp = GaussianProcess::train(
            kernel,
            xs,
            ys,
            NoiseModel::PerPoint(dy.map(|x| x * x)),
        )
        .expect("Should produce GP");
        let expected_k_chol: DMatrix<f64> = DMatrix::from_row_slice(
            6,
            6,
            &[
                4.333_051_38e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                4.880_168_69e-01,
                4.380_118_30e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                7.999_446_59e-04,
                4.826_837_17e-01,
                4.236_925_19e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                6.516_715_70e-06,
                3.335_496_65e-02,
                2.476_599_40e+00,
                3.527_532_47e+00,
                0.000_000_00e+00,
                0.000_000_00e+00,
                1.822_923_56e-08,
                7.913_467_57e-04,
                4.989_987_09e-01,
                2.628_868_78e+00,
                3.345_556_26e+00,
                0.000_000_00e+00,
                1.750_971_16e-11,
                6.446_687_85e-06,
                3.448_226_27e-02,
                5.752_472_07e-01,
                2.684_100_80e+00,
                3.278_532_35e+00,
            ],
        );

        assert!(gp.k_chol().l().relative_eq(&expected_k_chol, 1E-7, 1E-7));
        Ok(())
    }
821}