Skip to main content

egobox_gp/
correlation_models.rs

//! A module for correlation models with PLS weighting to model the error term of the GP model.
//!
//! The following correlation models are implemented:
//! * squared exponential,
//! * absolute exponential,
//! * matern 3/2,
//! * matern 5/2.
8
9use crate::utils::differences;
10use linfa::Float;
11use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2, Zip};
12#[cfg(feature = "serializable")]
13use serde::{Deserialize, Serialize};
14use std::convert::TryFrom;
15use std::fmt;
16
17/// A trait for using a correlation model in GP regression
18pub trait CorrelationModel<F: Float>: Clone + Copy + Default + fmt::Display + Sync {
19    /// Compute correlation function r(x, x') given x and a set of `x'` training samples, aka `xtrain`
20    /// `theta` parameters, and PLS `weights` with:
21    ///
22    /// * `x`      : point at which to compute correlation (shape nx)
23    /// * `xtrain` : training samples (shape nt x nx)
24    ///   where nx is the dimension of x and nt is the number of training samples (aka xtrain.nrows()).
25    /// * `theta`   : hyperparameters (shape 1 x nx)
26    /// * `weights` : PLS weights (shape nx x h) where h is the reduced dimension when PLS is used (kpls_dim).
27    ///
28    /// The returned correlation function matrix has shape (nt x 1) and corresponds to r(x, xtrain)
29    /// where r is the correlation function defined by the model.
30    fn rval(
31        &self,
32        x: &ArrayBase<impl Data<Elem = F>, Ix1>,
33        xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
34        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
35        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
36    ) -> Array2<F> {
37        let d = differences(x, xtrain);
38        self.rval_from_distances(&d, theta, weights)
39    }
40
41    /// Compute correlation function r(x, x') given distances `distances` between x and x',
42    /// `theta` parameters, and PLS `weights` with:
43    ///
44    /// * `distances`     : distances (nxd)
45    /// * `theta`   : hyperparameters (d,)
46    /// * `weights` : PLS weights (dxh)
47    ///   where d is the initial dimension and h (<d) is the reduced dimension when PLS is used (kpls_dim)
48    ///
49    /// The returned correlation function matrix has shape (nt x 1) and corresponds to r(x, xtrain)
50    /// where r is the correlation function defined by the model.
51    fn rval_from_distances(
52        &self,
53        distances: &ArrayBase<impl Data<Elem = F>, Ix2>,
54        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
55        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
56    ) -> Array2<F>;
57
58    /// Compute gradients of `r(x, x')` at given `x` given a set of `x'` training samples, aka `xtrain`,
59    /// `theta` parameters, and PLS `weights`.
60    /// The returned jacobian matrix is dr/dx where r is the correlation function vector between x and xtrain (shape nt).
61    /// Gradients are computed with respect to `x` and returned as a matrix of shape (nt, nx)
62    /// where nt is the number of training samples (aka xtrain.nrows()) and nx is the dimension of x.
63    fn jac(
64        &self,
65        x: &ArrayBase<impl Data<Elem = F>, Ix1>,
66        xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
67        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
68        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
69    ) -> Array2<F>;
70
71    /// Compute both the correlation function matrix `r(x, x')` and its jacobian at given `x`
72    /// given a set of `xtrain` training samples, `theta` parameters, and PLS `weights`.
73    /// Used to avoid redundant computations when both correlation and jacobian are needed.
74    fn rval_with_jac(
75        &self,
76        x: &ArrayBase<impl Data<Elem = F>, Ix1>,
77        xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
78        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
79        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
80    ) -> (Array2<F>, Array2<F>);
81
82    /// Returns the theta influence factors for the correlation model.
83    /// See <https://hal.science/hal-03812073v2/document>
84    fn theta_influence_factors(&self) -> (F, F) {
85        (F::one(), F::one())
86    }
87}
88
/// Squared exponential correlation model.
///
/// The type carries no state; it is converted to and from the string
/// `"SquaredExponential"` (this is also the representation used by serde when
/// the `serializable` feature is enabled).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[cfg_attr(
    feature = "serializable",
    derive(Serialize, Deserialize),
    serde(into = "String"),
    serde(try_from = "String")
)]
pub struct SquaredExponentialCorr();

impl From<SquaredExponentialCorr> for String {
    /// Returns the canonical model name `"SquaredExponential"`.
    fn from(_item: SquaredExponentialCorr) -> String {
        "SquaredExponential".to_string()
    }
}

impl TryFrom<String> for SquaredExponentialCorr {
    type Error = &'static str;

    /// Accepts exactly `"SquaredExponential"` (case-sensitive); any other
    /// string is rejected with a static error message.
    fn try_from(s: String) -> Result<Self, Self::Error> {
        match s.as_str() {
            "SquaredExponential" => Ok(Self::default()),
            _ => Err("Bad string value for SquaredExponentialCorr, should be 'SquaredExponential'"),
        }
    }
}
115
116impl<F: Float> CorrelationModel<F> for SquaredExponentialCorr {
117    ///   d    h
118    /// prod prod exp( - |theta_l * weight_j_l * d_j|^2 / 2 )
119    ///  j=1  l=1
120    fn rval_from_distances(
121        &self,
122        d: &ArrayBase<impl Data<Elem = F>, Ix2>,
123        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
124        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
125    ) -> Array2<F> {
126        let theta_w_sq = (theta * weights).mapv(|v| v * v).sum_axis(Axis(1));
127        let r = d.mapv(|v| v * v).dot(&theta_w_sq);
128        r.mapv(|v| F::exp(F::cast(-0.5) * v))
129            .into_shape_with_order((d.nrows(), 1))
130            .unwrap()
131    }
132
133    fn jac(
134        &self,
135        x: &ArrayBase<impl Data<Elem = F>, Ix1>,
136        xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
137        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
138        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
139    ) -> Array2<F> {
140        let d = differences(x, xtrain);
141        let neg_theta_w_sq = (theta * weights).mapv(|v| -(v * v)).sum_axis(Axis(1));
142        let r = {
143            let exponent = d.mapv(|v| v * v).dot(&neg_theta_w_sq.mapv(|v| -v));
144            exponent
145                .mapv(|v| F::exp(F::cast(-0.5) * v))
146                .into_shape_with_order((d.nrows(), 1))
147                .unwrap()
148        };
149        d * &neg_theta_w_sq * &r
150    }
151
152    fn rval_with_jac(
153        &self,
154        x: &ArrayBase<impl Data<Elem = F>, Ix1>,
155        xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
156        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
157        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
158    ) -> (Array2<F>, Array2<F>) {
159        let d = differences(x, xtrain);
160        let neg_theta_w_sq = (theta * weights).mapv(|v| -(v * v)).sum_axis(Axis(1));
161        let r = {
162            let exponent = d.mapv(|v| v * v).dot(&neg_theta_w_sq.mapv(|v| -v));
163            exponent
164                .mapv(|v| F::exp(F::cast(-0.5) * v))
165                .into_shape_with_order((d.nrows(), 1))
166                .unwrap()
167        };
168        let jr = d * &neg_theta_w_sq * &r;
169        (r, jr)
170    }
171
172    fn theta_influence_factors(&self) -> (F, F) {
173        (F::cast(0.29), F::cast(1.96))
174    }
175}
176
177impl fmt::Display for SquaredExponentialCorr {
178    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
179        write!(f, "SquaredExponential")
180    }
181}
182
/// Absolute exponential correlation model.
///
/// The type carries no state; it is converted to and from the string
/// `"AbsoluteExponential"` (this is also the representation used by serde when
/// the `serializable` feature is enabled).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "serializable",
    derive(Serialize, Deserialize),
    serde(into = "String"),
    serde(try_from = "String")
)]
pub struct AbsoluteExponentialCorr();

impl From<AbsoluteExponentialCorr> for String {
    /// Returns the canonical model name `"AbsoluteExponential"`.
    fn from(_item: AbsoluteExponentialCorr) -> String {
        "AbsoluteExponential".to_string()
    }
}

impl TryFrom<String> for AbsoluteExponentialCorr {
    type Error = &'static str;

    /// Accepts exactly `"AbsoluteExponential"` (case-sensitive); any other
    /// string is rejected with a static error message.
    fn try_from(s: String) -> Result<Self, Self::Error> {
        match s.as_str() {
            "AbsoluteExponential" => Ok(Self::default()),
            _ => Err("Bad string value for AbsoluteExponentialCorr, should be 'AbsoluteExponential'"),
        }
    }
}
209
210impl<F: Float> CorrelationModel<F> for AbsoluteExponentialCorr {
211    ///   d    h
212    /// prod prod exp( - theta_l * |weight_j_l * d_j| )
213    ///  j=1  l=1
214    fn rval_from_distances(
215        &self,
216        d: &ArrayBase<impl Data<Elem = F>, Ix2>,
217        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
218        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
219    ) -> Array2<F> {
220        let theta_w = weights.mapv(|v| v.abs()).dot(theta);
221        let r = d.mapv(|v| v.abs()).dot(&theta_w);
222        r.mapv(|v| F::exp(-v))
223            .into_shape_with_order((d.nrows(), 1))
224            .unwrap()
225    }
226
227    fn jac(
228        &self,
229        x: &ArrayBase<impl Data<Elem = F>, Ix1>,
230        xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
231        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
232        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
233    ) -> Array2<F> {
234        let d = differences(x, xtrain);
235        let r = self.rval_from_distances(&d, theta, weights);
236        let sign_d = d.mapv(|v| v.signum());
237
238        let dtheta_w = sign_d
239            * (theta * weights.mapv(|v| v.abs()))
240                .sum_axis(Axis(1))
241                .mapv(|v| F::cast(-1.) * v);
242        &dtheta_w * &r
243    }
244
245    fn rval_with_jac(
246        &self,
247        x: &ArrayBase<impl Data<Elem = F>, Ix1>,
248        xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
249        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
250        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
251    ) -> (Array2<F>, Array2<F>) {
252        let d = differences(x, xtrain);
253        let neg_theta_w = (theta * weights.mapv(|v| v.abs()))
254            .sum_axis(Axis(1))
255            .mapv(|v| -v);
256        let r = {
257            let exponent = d.mapv(|v| v.abs()).dot(&neg_theta_w.mapv(|v| -v));
258            exponent
259                .mapv(|v| F::exp(-v))
260                .into_shape_with_order((d.nrows(), 1))
261                .unwrap()
262        };
263        let jr = &(d.mapv(|v| v.signum()) * &neg_theta_w) * &r;
264        (r, jr)
265    }
266
267    fn theta_influence_factors(&self) -> (F, F) {
268        (F::cast(0.15), F::cast(3.76))
269    }
270}
271
272impl fmt::Display for AbsoluteExponentialCorr {
273    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
274        write!(f, "AbsoluteExponential")
275    }
276}
277
/// Matern 3/2 correlation model.
///
/// The type carries no state; it is converted to and from the string
/// `"Matern32"` (this is also the representation used by serde when the
/// `serializable` feature is enabled).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "serializable",
    derive(Serialize, Deserialize),
    serde(into = "String"),
    serde(try_from = "String")
)]
pub struct Matern32Corr();

impl From<Matern32Corr> for String {
    /// Returns the canonical model name `"Matern32"`.
    fn from(_item: Matern32Corr) -> String {
        "Matern32".to_string()
    }
}

impl TryFrom<String> for Matern32Corr {
    type Error = &'static str;

    /// Accepts exactly `"Matern32"` (case-sensitive); any other string is
    /// rejected with a static error message.
    fn try_from(s: String) -> Result<Self, Self::Error> {
        match s.as_str() {
            "Matern32" => Ok(Self::default()),
            _ => Err("Bad string value for Matern32Corr, should be 'Matern32'"),
        }
    }
}
304
305impl<F: Float> CorrelationModel<F> for Matern32Corr {
306    ///   d    h         
307    /// prod prod (1 + sqrt(3) * theta_l * |d_j . weight_j|) exp( - sqrt(3) * theta_l * |d_j . weight_j| )
308    ///  j=1  l=1
309    fn rval_from_distances(
310        &self,
311        d: &ArrayBase<impl Data<Elem = F>, Ix2>,
312        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
313        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
314    ) -> Array2<F> {
315        let sqrt3 = F::cast(3.).sqrt();
316        let theta_w = theta * weights.mapv(|v| v.abs());
317
318        let mut r = Array1::zeros(d.nrows());
319        Zip::from(&mut r).and(d.rows()).for_each(|r_i, d_i| {
320            let mut a = F::one();
321            let mut b_sum = F::zero();
322            Zip::from(&d_i).and(theta_w.rows()).for_each(|&d_ij, tw_j| {
323                let abs_d = d_ij.abs();
324                let mut prod = F::one();
325                for &tw in tw_j.iter() {
326                    prod *= F::one() + sqrt3 * tw * abs_d;
327                    b_sum += tw * abs_d;
328                }
329                a *= prod;
330            });
331            *r_i = a * F::exp(-sqrt3 * b_sum);
332        });
333        r.into_shape_with_order((d.nrows(), 1)).unwrap()
334    }
335
336    fn jac(
337        &self,
338        x: &ArrayBase<impl Data<Elem = F>, Ix1>,
339        xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
340        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
341        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
342    ) -> Array2<F> {
343        let d = differences(x, xtrain);
344        let r = self.rval_from_distances(&d, theta, weights);
345        self._jac_from_r(&d, &r, theta, weights)
346    }
347
348    fn rval_with_jac(
349        &self,
350        x: &ArrayBase<impl Data<Elem = F>, Ix1>,
351        xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
352        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
353        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
354    ) -> (Array2<F>, Array2<F>) {
355        let d = differences(x, xtrain);
356        let r = self.rval_from_distances(&d, theta, weights);
357        let jr = self._jac_from_r(&d, &r, theta, weights);
358        (r, jr)
359    }
360
361    fn theta_influence_factors(&self) -> (F, F) {
362        (F::cast(0.21), F::cast(2.74))
363    }
364}
365
366impl fmt::Display for Matern32Corr {
367    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
368        write!(f, "Matern32")
369    }
370}
371
372impl Matern32Corr {
373    /// Compute the jacobian dr/dx from precomputed distances and correlation values.
374    ///
375    /// For Matern 3/2, f(u) = 1 + √3·u is always positive, so the
376    /// "product-excluding-one-factor" can be computed via division, reducing
377    /// the O(n·d²·h²) nested loop to O(n·d·h).
378    fn _jac_from_r<F: Float>(
379        &self,
380        d: &ArrayBase<impl Data<Elem = F>, Ix2>,
381        r: &ArrayBase<impl Data<Elem = F>, Ix2>,
382        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
383        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
384    ) -> Array2<F> {
385        let three = F::cast(3.);
386        let sqrt3 = three.sqrt();
387        let neg3 = F::cast(-3.);
388        let theta_w = theta * weights.mapv(|v| v.abs());
389
390        let mut jr = Array2::zeros((d.nrows(), d.ncols()));
391        Zip::from(jr.rows_mut())
392            .and(d.rows())
393            .and(r.column(0))
394            .for_each(|mut jr_i, d_i, &r_i| {
395                Zip::from(&mut jr_i).and(&d_i).and(theta_w.rows()).for_each(
396                    |jr_ij, &d_ij, tw_j| {
397                        let abs_d = d_ij.abs();
398                        let sign_d = d_ij.signum();
399                        let mut sum = F::zero();
400                        for &tw in tw_j.iter() {
401                            let f = F::one() + sqrt3 * tw * abs_d;
402                            sum += tw * tw * abs_d / f;
403                        }
404                        *jr_ij = neg3 * sign_d * r_i * sum;
405                    },
406                );
407            });
408        jr
409    }
410}
411
/// Matern 5/2 correlation model.
///
/// The type carries no state; it is converted to and from the string
/// `"Matern52"` (this is also the representation used by serde when the
/// `serializable` feature is enabled).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "serializable",
    derive(Serialize, Deserialize),
    serde(into = "String"),
    serde(try_from = "String")
)]
pub struct Matern52Corr();

impl From<Matern52Corr> for String {
    /// Returns the canonical model name `"Matern52"`.
    fn from(_item: Matern52Corr) -> String {
        "Matern52".to_string()
    }
}

impl TryFrom<String> for Matern52Corr {
    type Error = &'static str;

    /// Accepts exactly `"Matern52"` (case-sensitive); any other string is
    /// rejected with a static error message.
    fn try_from(s: String) -> Result<Self, Self::Error> {
        match s.as_str() {
            "Matern52" => Ok(Self::default()),
            _ => Err("Bad string value for Matern52Corr, should be 'Matern52'"),
        }
    }
}
438
439impl<F: Float> CorrelationModel<F> for Matern52Corr {
440    ///   d    h         
441    /// prod prod (1 + sqrt(5) * theta_l * |d_j . weight_j| + (5./3.) * theta_l^2 * |d_j . weight_j|^2) exp( - sqrt(5) * theta_l * |d_j . weight_j| )
442    ///  j=1  l=1
443    fn rval_from_distances(
444        &self,
445        d: &ArrayBase<impl Data<Elem = F>, Ix2>,
446        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
447        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
448    ) -> Array2<F> {
449        let sqrt5 = F::cast(5.).sqrt();
450        let div5_3 = F::cast(5. / 3.);
451        let theta_w = theta * weights.mapv(|v| v.abs());
452
453        let mut r = Array1::zeros(d.nrows());
454        Zip::from(&mut r).and(d.rows()).for_each(|r_i, d_i| {
455            let mut a = F::one();
456            let mut b_sum = F::zero();
457            Zip::from(&d_i).and(theta_w.rows()).for_each(|&d_ij, tw_j| {
458                let abs_d = d_ij.abs();
459                let mut prod = F::one();
460                for &tw in tw_j.iter() {
461                    let u = tw * abs_d;
462                    prod *= F::one() + sqrt5 * u + div5_3 * u * u;
463                    b_sum += tw * abs_d;
464                }
465                a *= prod;
466            });
467            *r_i = a * F::exp(-sqrt5 * b_sum);
468        });
469        r.into_shape_with_order((d.nrows(), 1)).unwrap()
470    }
471
472    fn jac(
473        &self,
474        x: &ArrayBase<impl Data<Elem = F>, Ix1>,
475        xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
476        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
477        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
478    ) -> Array2<F> {
479        let d = differences(x, xtrain);
480        let r = self.rval_from_distances(&d, theta, weights);
481        self._jac_from_r(&d, &r, theta, weights)
482    }
483
484    fn rval_with_jac(
485        &self,
486        x: &ArrayBase<impl Data<Elem = F>, Ix1>,
487        xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
488        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
489        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
490    ) -> (Array2<F>, Array2<F>) {
491        let d = differences(x, xtrain);
492        let r = self.rval_from_distances(&d, theta, weights);
493        let jr = self._jac_from_r(&d, &r, theta, weights);
494        (r, jr)
495    }
496
497    fn theta_influence_factors(&self) -> (F, F) {
498        (F::cast(0.23), F::cast(2.44))
499    }
500}
501
502impl fmt::Display for Matern52Corr {
503    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
504        write!(f, "Matern52")
505    }
506}
507
508impl Matern52Corr {
509    /// Compute the jacobian dr/dx from precomputed distances and correlation values.
510    ///
511    /// Uses the algebraic identity: combining the db (exponential derivative) and da
512    /// (polynomial derivative) terms yields a closed-form O(n·d·h) formula, replacing
513    /// the O(n·d²·h²) "product-excluding-one-factor" loop. This is possible because
514    /// the Matern 5/2 polynomial f(u) = 1 + √5u + 5/3·u² is always positive
515    /// (negative discriminant), so division by f is safe.
516    fn _jac_from_r<F: Float>(
517        &self,
518        d: &ArrayBase<impl Data<Elem = F>, Ix2>,
519        r: &ArrayBase<impl Data<Elem = F>, Ix2>,
520        theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
521        weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
522    ) -> Array2<F> {
523        let sqrt5 = F::cast(5.).sqrt();
524        let div5_3 = F::cast(5. / 3.);
525        let neg5_3 = F::cast(-5. / 3.);
526        let theta_w = theta * weights.mapv(|v| v.abs());
527
528        let mut jr = Array2::zeros((d.nrows(), d.ncols()));
529        Zip::from(jr.rows_mut())
530            .and(d.rows())
531            .and(r.column(0))
532            .for_each(|mut jr_i, d_i, &r_i| {
533                Zip::from(&mut jr_i).and(&d_i).and(theta_w.rows()).for_each(
534                    |jr_ij, &d_ij, tw_j| {
535                        let abs_d = d_ij.abs();
536                        let sign_d = d_ij.signum();
537                        let mut sum = F::zero();
538                        for &tw in tw_j.iter() {
539                            let u = tw * abs_d;
540                            let f = F::one() + sqrt5 * u + div5_3 * u * u;
541                            sum += tw * tw * abs_d * (F::one() + sqrt5 * u) / f;
542                        }
543                        *jr_ij = neg5_3 * sign_d * r_i * sum;
544                    },
545                );
546            });
547        jr
548    }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::{DiffMatrix, NormalizedData};
    use approx::assert_abs_diff_eq;
    use ndarray::{arr1, array};
    use paste::paste;

    /// 1-D check: with theta = sqrt(0.2) and unit weight the expected values are
    /// exp(-0.1 * d^2) for every pairwise difference d produced by `DiffMatrix`.
    #[test]
    fn test_squared_exponential() {
        let xt = array![[4.5], [1.2], [2.0], [3.0], [4.0]];
        let dm = DiffMatrix::new(&xt);
        let res = SquaredExponentialCorr::default().rval_from_distances(
            &dm.d,
            &arr1(&[f64::sqrt(0.2)]),
            &array![[1.]],
        );
        let expected = array![
            [0.336552878364737],
            [0.5352614285189903],
            [0.7985162187593771],
            [0.9753099120283326],
            [0.9380049995307295],
            [0.7232502423798424],
            [0.4565760496233148],
            [0.9048374180359595],
            [0.6703200460356393],
            [0.9048374180359595]
        ];
        assert_abs_diff_eq!(res, expected, epsilon = 1e-6);
    }

    /// 2-D check with identity weights (no PLS reduction).
    #[test]
    fn test_squared_exponential_2d() {
        let xt = array![[0., 1.], [2., 3.], [4., 5.]];
        let dm = DiffMatrix::new(&xt);
        let res = SquaredExponentialCorr::default().rval_from_distances(
            &dm.d,
            &arr1(&[f64::sqrt(2.), 2.]),
            &array![[1., 0.], [0., 1.]],
        );
        let expected = array![[6.14421235e-06], [1.42516408e-21], [6.14421235e-06]];
        assert_abs_diff_eq!(res, expected, epsilon = 1e-6);
    }

    /// 2-D check with identity weights (no PLS reduction).
    #[test]
    fn test_matern32_2d() {
        let xt = array![[0., 1.], [2., 3.], [4., 5.]];
        let dm = DiffMatrix::new(&xt);
        let res = Matern32Corr::default().rval_from_distances(
            &dm.d,
            &arr1(&[1., 2.]),
            &array![[1., 0.], [0., 1.]],
        );
        let expected = array![[1.08539595e-03], [1.10776401e-07], [1.08539595e-03]];
        assert_abs_diff_eq!(res, expected, epsilon = 1e-6);
    }

    // Generates a test comparing the analytical jacobian of a correlation model
    // against central finite differences of `rval_from_distances`, either with
    // a single PLS component (`$kpls == true`) or with identity weights.
    macro_rules! test_correlation {
        ($corr:ident, $kpls:expr_2021) => {
            paste! {
                #[test]
                fn [<test_corr_ $corr:lower _kpls_ $kpls _derivatives>]() {
                    let x = array![3., 5.];
                    let xt = array![
                        [-9.375, -5.625],
                        [-5.625, -4.375],
                        [9.375, 1.875],
                        [8.125, 5.625],
                        [-4.375, -0.625],
                        [6.875, -3.125],
                        [4.375, 9.375],
                        [3.125, 4.375],
                        [5.625, -8.125],
                        [-8.125, 3.125],
                        [1.875, -6.875],
                        [-0.625, 8.125],
                        [-1.875, -1.875],
                        [0.625, 0.625],
                        [-6.875, -9.375],
                        [-3.125, 6.875]
                    ];
                    let xtrain = NormalizedData::new(&xt);
                    let xnorm = (x.to_owned() - &xtrain.mean) / &xtrain.std;
                    let (theta, weights) = if $kpls {
                        (array![0.31059002],
                            array![[-0.02701716],
                            [-0.99963497]])
                    } else {
                        (array![0.34599115925909146, 0.32083374253611624],
                         array![[1., 0.], [0., 1.]])
                    };

                    let corr = [< $corr Corr >]::default();
                    // Chain rule: the jacobian is taken w.r.t. the normalized x,
                    // so divide by std to express it w.r.t. the raw x.
                    let jac = corr.jac(&xnorm, &xtrain.data, &theta, &weights) / &xtrain.std;
                    println!("Jacobian: \n{:?}", jac);
                    let xa: f64 = x[0];
                    let xb: f64 = x[1];
                    let e = 1e-5;
                    let x = array![
                        [xa, xb],
                        [xa + e, xb],
                        [xa - e, xb],
                        [xa, xb + e],
                        [xa, xb - e]
                    ];

                    let mut rxx = Array2::zeros((xtrain.data.nrows(), x.nrows()));
                    Zip::from(rxx.columns_mut())
                        .and(x.rows())
                        .for_each(|mut rxxi, xi| {
                            let xnorm = (xi.to_owned() - &xtrain.mean) / &xtrain.std;
                            let d = differences(&xnorm, &xtrain.data);
                            rxxi.assign(&(corr.rval_from_distances( &d, &theta, &weights).column(0)));
                        });
                    let fdiffa = (rxx.column(1).to_owned() - rxx.column(2)).mapv(|v| v / (2. * e));
                    assert_abs_diff_eq!(fdiffa, jac.column(0), epsilon=1e-6);
                    let fdiffb = (rxx.column(3).to_owned() - rxx.column(4)).mapv(|v| v / (2. * e));
                    assert_abs_diff_eq!(fdiffb, jac.column(1), epsilon=1e-6);
                }
            }
        };
    }

    test_correlation!(SquaredExponential, false);
    test_correlation!(AbsoluteExponential, false);
    test_correlation!(Matern32, false);
    test_correlation!(Matern52, false);
    test_correlation!(SquaredExponential, true);
    test_correlation!(AbsoluteExponential, true);
    test_correlation!(Matern32, true);
    test_correlation!(Matern52, true);

    /// 2-D check with identity weights (no PLS reduction).
    #[test]
    fn test_matern52_2d() {
        let xt = array![[0., 1.], [2., 3.], [4., 5.]];
        let dm = DiffMatrix::new(&xt);
        let res = Matern52Corr::default().rval_from_distances(
            &dm.d,
            &arr1(&[1., 2.]),
            &array![[1., 0.], [0., 1.]],
        );
        let expected = array![[6.62391590e-04], [1.02117882e-08], [6.62391590e-04]];
        assert_abs_diff_eq!(res, expected, epsilon = 1e-6);
    }
}