1use crate::utils::differences;
10use linfa::Float;
11use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2, Zip};
12#[cfg(feature = "serializable")]
13use serde::{Deserialize, Serialize};
14use std::convert::TryFrom;
15use std::fmt;
16
17pub trait CorrelationModel<F: Float>: Clone + Copy + Default + fmt::Display + Sync {
19 fn rval(
31 &self,
32 x: &ArrayBase<impl Data<Elem = F>, Ix1>,
33 xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
34 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
35 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
36 ) -> Array2<F> {
37 let d = differences(x, xtrain);
38 self.rval_from_distances(&d, theta, weights)
39 }
40
41 fn rval_from_distances(
52 &self,
53 distances: &ArrayBase<impl Data<Elem = F>, Ix2>,
54 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
55 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
56 ) -> Array2<F>;
57
58 fn jac(
64 &self,
65 x: &ArrayBase<impl Data<Elem = F>, Ix1>,
66 xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
67 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
68 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
69 ) -> Array2<F>;
70
71 fn rval_with_jac(
75 &self,
76 x: &ArrayBase<impl Data<Elem = F>, Ix1>,
77 xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
78 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
79 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
80 ) -> (Array2<F>, Array2<F>);
81
82 fn theta_influence_factors(&self) -> (F, F) {
85 (F::one(), F::one())
86 }
87}
88
/// Squared exponential (Gaussian) correlation model.
///
/// With the `serializable` feature it is (de)serialized through the plain
/// string `"SquaredExponential"` using the two `String` conversions below.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "serializable",
    derive(Serialize, Deserialize),
    serde(into = "String"),
    serde(try_from = "String")
)]
pub struct SquaredExponentialCorr();

impl From<SquaredExponentialCorr> for String {
    /// Serialized name of the model.
    fn from(_corr: SquaredExponentialCorr) -> Self {
        String::from("SquaredExponential")
    }
}

impl TryFrom<String> for SquaredExponentialCorr {
    type Error = &'static str;

    /// Accepts exactly `"SquaredExponential"`; any other text is rejected.
    fn try_from(s: String) -> Result<Self, Self::Error> {
        match s.as_str() {
            "SquaredExponential" => Ok(Self::default()),
            _ => Err("Bad string value for SquaredExponentialCorr, should be \'SquaredExponential\'"),
        }
    }
}
115
116impl<F: Float> CorrelationModel<F> for SquaredExponentialCorr {
117 fn rval_from_distances(
121 &self,
122 d: &ArrayBase<impl Data<Elem = F>, Ix2>,
123 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
124 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
125 ) -> Array2<F> {
126 let theta_w_sq = (theta * weights).mapv(|v| v * v).sum_axis(Axis(1));
127 let r = d.mapv(|v| v * v).dot(&theta_w_sq);
128 r.mapv(|v| F::exp(F::cast(-0.5) * v))
129 .into_shape_with_order((d.nrows(), 1))
130 .unwrap()
131 }
132
133 fn jac(
134 &self,
135 x: &ArrayBase<impl Data<Elem = F>, Ix1>,
136 xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
137 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
138 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
139 ) -> Array2<F> {
140 let d = differences(x, xtrain);
141 let neg_theta_w_sq = (theta * weights).mapv(|v| -(v * v)).sum_axis(Axis(1));
142 let r = {
143 let exponent = d.mapv(|v| v * v).dot(&neg_theta_w_sq.mapv(|v| -v));
144 exponent
145 .mapv(|v| F::exp(F::cast(-0.5) * v))
146 .into_shape_with_order((d.nrows(), 1))
147 .unwrap()
148 };
149 d * &neg_theta_w_sq * &r
150 }
151
152 fn rval_with_jac(
153 &self,
154 x: &ArrayBase<impl Data<Elem = F>, Ix1>,
155 xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
156 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
157 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
158 ) -> (Array2<F>, Array2<F>) {
159 let d = differences(x, xtrain);
160 let neg_theta_w_sq = (theta * weights).mapv(|v| -(v * v)).sum_axis(Axis(1));
161 let r = {
162 let exponent = d.mapv(|v| v * v).dot(&neg_theta_w_sq.mapv(|v| -v));
163 exponent
164 .mapv(|v| F::exp(F::cast(-0.5) * v))
165 .into_shape_with_order((d.nrows(), 1))
166 .unwrap()
167 };
168 let jr = d * &neg_theta_w_sq * &r;
169 (r, jr)
170 }
171
172 fn theta_influence_factors(&self) -> (F, F) {
173 (F::cast(0.29), F::cast(1.96))
174 }
175}
176
177impl fmt::Display for SquaredExponentialCorr {
178 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
179 write!(f, "SquaredExponential")
180 }
181}
182
/// Absolute exponential correlation model.
///
/// With the `serializable` feature it is (de)serialized through the plain
/// string `"AbsoluteExponential"` using the two `String` conversions below.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "serializable",
    derive(Serialize, Deserialize),
    serde(into = "String"),
    serde(try_from = "String")
)]
pub struct AbsoluteExponentialCorr();

impl From<AbsoluteExponentialCorr> for String {
    /// Serialized name of the model.
    fn from(_corr: AbsoluteExponentialCorr) -> Self {
        String::from("AbsoluteExponential")
    }
}

impl TryFrom<String> for AbsoluteExponentialCorr {
    type Error = &'static str;

    /// Accepts exactly `"AbsoluteExponential"`; any other text is rejected.
    fn try_from(s: String) -> Result<Self, Self::Error> {
        match s.as_str() {
            "AbsoluteExponential" => Ok(Self::default()),
            _ => Err("Bad string value for AbsoluteExponentialCorr, should be \'AbsoluteExponential\'"),
        }
    }
}
209
210impl<F: Float> CorrelationModel<F> for AbsoluteExponentialCorr {
211 fn rval_from_distances(
215 &self,
216 d: &ArrayBase<impl Data<Elem = F>, Ix2>,
217 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
218 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
219 ) -> Array2<F> {
220 let theta_w = weights.mapv(|v| v.abs()).dot(theta);
221 let r = d.mapv(|v| v.abs()).dot(&theta_w);
222 r.mapv(|v| F::exp(-v))
223 .into_shape_with_order((d.nrows(), 1))
224 .unwrap()
225 }
226
227 fn jac(
228 &self,
229 x: &ArrayBase<impl Data<Elem = F>, Ix1>,
230 xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
231 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
232 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
233 ) -> Array2<F> {
234 let d = differences(x, xtrain);
235 let r = self.rval_from_distances(&d, theta, weights);
236 let sign_d = d.mapv(|v| v.signum());
237
238 let dtheta_w = sign_d
239 * (theta * weights.mapv(|v| v.abs()))
240 .sum_axis(Axis(1))
241 .mapv(|v| F::cast(-1.) * v);
242 &dtheta_w * &r
243 }
244
245 fn rval_with_jac(
246 &self,
247 x: &ArrayBase<impl Data<Elem = F>, Ix1>,
248 xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
249 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
250 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
251 ) -> (Array2<F>, Array2<F>) {
252 let d = differences(x, xtrain);
253 let neg_theta_w = (theta * weights.mapv(|v| v.abs()))
254 .sum_axis(Axis(1))
255 .mapv(|v| -v);
256 let r = {
257 let exponent = d.mapv(|v| v.abs()).dot(&neg_theta_w.mapv(|v| -v));
258 exponent
259 .mapv(|v| F::exp(-v))
260 .into_shape_with_order((d.nrows(), 1))
261 .unwrap()
262 };
263 let jr = &(d.mapv(|v| v.signum()) * &neg_theta_w) * &r;
264 (r, jr)
265 }
266
267 fn theta_influence_factors(&self) -> (F, F) {
268 (F::cast(0.15), F::cast(3.76))
269 }
270}
271
272impl fmt::Display for AbsoluteExponentialCorr {
273 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
274 write!(f, "AbsoluteExponential")
275 }
276}
277
/// Matérn 3/2 correlation model.
///
/// With the `serializable` feature it is (de)serialized through the plain
/// string `"Matern32"` using the two `String` conversions below.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "serializable",
    derive(Serialize, Deserialize),
    serde(into = "String"),
    serde(try_from = "String")
)]
pub struct Matern32Corr();

impl From<Matern32Corr> for String {
    /// Serialized name of the model.
    fn from(_corr: Matern32Corr) -> Self {
        String::from("Matern32")
    }
}

impl TryFrom<String> for Matern32Corr {
    type Error = &'static str;

    /// Accepts exactly `"Matern32"`; any other text is rejected.
    fn try_from(s: String) -> Result<Self, Self::Error> {
        match s.as_str() {
            "Matern32" => Ok(Self::default()),
            _ => Err("Bad string value for Matern32Corr, should be \'Matern32\'"),
        }
    }
}
304
305impl<F: Float> CorrelationModel<F> for Matern32Corr {
306 fn rval_from_distances(
310 &self,
311 d: &ArrayBase<impl Data<Elem = F>, Ix2>,
312 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
313 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
314 ) -> Array2<F> {
315 let sqrt3 = F::cast(3.).sqrt();
316 let theta_w = theta * weights.mapv(|v| v.abs());
317
318 let mut r = Array1::zeros(d.nrows());
319 Zip::from(&mut r).and(d.rows()).for_each(|r_i, d_i| {
320 let mut a = F::one();
321 let mut b_sum = F::zero();
322 Zip::from(&d_i).and(theta_w.rows()).for_each(|&d_ij, tw_j| {
323 let abs_d = d_ij.abs();
324 let mut prod = F::one();
325 for &tw in tw_j.iter() {
326 prod *= F::one() + sqrt3 * tw * abs_d;
327 b_sum += tw * abs_d;
328 }
329 a *= prod;
330 });
331 *r_i = a * F::exp(-sqrt3 * b_sum);
332 });
333 r.into_shape_with_order((d.nrows(), 1)).unwrap()
334 }
335
336 fn jac(
337 &self,
338 x: &ArrayBase<impl Data<Elem = F>, Ix1>,
339 xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
340 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
341 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
342 ) -> Array2<F> {
343 let d = differences(x, xtrain);
344 let r = self.rval_from_distances(&d, theta, weights);
345 self._jac_from_r(&d, &r, theta, weights)
346 }
347
348 fn rval_with_jac(
349 &self,
350 x: &ArrayBase<impl Data<Elem = F>, Ix1>,
351 xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
352 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
353 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
354 ) -> (Array2<F>, Array2<F>) {
355 let d = differences(x, xtrain);
356 let r = self.rval_from_distances(&d, theta, weights);
357 let jr = self._jac_from_r(&d, &r, theta, weights);
358 (r, jr)
359 }
360
361 fn theta_influence_factors(&self) -> (F, F) {
362 (F::cast(0.21), F::cast(2.74))
363 }
364}
365
366impl fmt::Display for Matern32Corr {
367 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
368 write!(f, "Matern32")
369 }
370}
371
372impl Matern32Corr {
373 fn _jac_from_r<F: Float>(
379 &self,
380 d: &ArrayBase<impl Data<Elem = F>, Ix2>,
381 r: &ArrayBase<impl Data<Elem = F>, Ix2>,
382 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
383 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
384 ) -> Array2<F> {
385 let three = F::cast(3.);
386 let sqrt3 = three.sqrt();
387 let neg3 = F::cast(-3.);
388 let theta_w = theta * weights.mapv(|v| v.abs());
389
390 let mut jr = Array2::zeros((d.nrows(), d.ncols()));
391 Zip::from(jr.rows_mut())
392 .and(d.rows())
393 .and(r.column(0))
394 .for_each(|mut jr_i, d_i, &r_i| {
395 Zip::from(&mut jr_i).and(&d_i).and(theta_w.rows()).for_each(
396 |jr_ij, &d_ij, tw_j| {
397 let abs_d = d_ij.abs();
398 let sign_d = d_ij.signum();
399 let mut sum = F::zero();
400 for &tw in tw_j.iter() {
401 let f = F::one() + sqrt3 * tw * abs_d;
402 sum += tw * tw * abs_d / f;
403 }
404 *jr_ij = neg3 * sign_d * r_i * sum;
405 },
406 );
407 });
408 jr
409 }
410}
411
/// Matérn 5/2 correlation model.
///
/// With the `serializable` feature it is (de)serialized through the plain
/// string `"Matern52"` using the two `String` conversions below.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "serializable",
    derive(Serialize, Deserialize),
    serde(into = "String"),
    serde(try_from = "String")
)]
pub struct Matern52Corr();

impl From<Matern52Corr> for String {
    /// Serialized name of the model.
    fn from(_corr: Matern52Corr) -> Self {
        String::from("Matern52")
    }
}

impl TryFrom<String> for Matern52Corr {
    type Error = &'static str;

    /// Accepts exactly `"Matern52"`; any other text is rejected.
    fn try_from(s: String) -> Result<Self, Self::Error> {
        match s.as_str() {
            "Matern52" => Ok(Self::default()),
            _ => Err("Bad string value for Matern52Corr, should be \'Matern52\'"),
        }
    }
}
438
439impl<F: Float> CorrelationModel<F> for Matern52Corr {
440 fn rval_from_distances(
444 &self,
445 d: &ArrayBase<impl Data<Elem = F>, Ix2>,
446 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
447 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
448 ) -> Array2<F> {
449 let sqrt5 = F::cast(5.).sqrt();
450 let div5_3 = F::cast(5. / 3.);
451 let theta_w = theta * weights.mapv(|v| v.abs());
452
453 let mut r = Array1::zeros(d.nrows());
454 Zip::from(&mut r).and(d.rows()).for_each(|r_i, d_i| {
455 let mut a = F::one();
456 let mut b_sum = F::zero();
457 Zip::from(&d_i).and(theta_w.rows()).for_each(|&d_ij, tw_j| {
458 let abs_d = d_ij.abs();
459 let mut prod = F::one();
460 for &tw in tw_j.iter() {
461 let u = tw * abs_d;
462 prod *= F::one() + sqrt5 * u + div5_3 * u * u;
463 b_sum += tw * abs_d;
464 }
465 a *= prod;
466 });
467 *r_i = a * F::exp(-sqrt5 * b_sum);
468 });
469 r.into_shape_with_order((d.nrows(), 1)).unwrap()
470 }
471
472 fn jac(
473 &self,
474 x: &ArrayBase<impl Data<Elem = F>, Ix1>,
475 xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
476 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
477 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
478 ) -> Array2<F> {
479 let d = differences(x, xtrain);
480 let r = self.rval_from_distances(&d, theta, weights);
481 self._jac_from_r(&d, &r, theta, weights)
482 }
483
484 fn rval_with_jac(
485 &self,
486 x: &ArrayBase<impl Data<Elem = F>, Ix1>,
487 xtrain: &ArrayBase<impl Data<Elem = F>, Ix2>,
488 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
489 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
490 ) -> (Array2<F>, Array2<F>) {
491 let d = differences(x, xtrain);
492 let r = self.rval_from_distances(&d, theta, weights);
493 let jr = self._jac_from_r(&d, &r, theta, weights);
494 (r, jr)
495 }
496
497 fn theta_influence_factors(&self) -> (F, F) {
498 (F::cast(0.23), F::cast(2.44))
499 }
500}
501
502impl fmt::Display for Matern52Corr {
503 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
504 write!(f, "Matern52")
505 }
506}
507
508impl Matern52Corr {
509 fn _jac_from_r<F: Float>(
517 &self,
518 d: &ArrayBase<impl Data<Elem = F>, Ix2>,
519 r: &ArrayBase<impl Data<Elem = F>, Ix2>,
520 theta: &ArrayBase<impl Data<Elem = F>, Ix1>,
521 weights: &ArrayBase<impl Data<Elem = F>, Ix2>,
522 ) -> Array2<F> {
523 let sqrt5 = F::cast(5.).sqrt();
524 let div5_3 = F::cast(5. / 3.);
525 let neg5_3 = F::cast(-5. / 3.);
526 let theta_w = theta * weights.mapv(|v| v.abs());
527
528 let mut jr = Array2::zeros((d.nrows(), d.ncols()));
529 Zip::from(jr.rows_mut())
530 .and(d.rows())
531 .and(r.column(0))
532 .for_each(|mut jr_i, d_i, &r_i| {
533 Zip::from(&mut jr_i).and(&d_i).and(theta_w.rows()).for_each(
534 |jr_ij, &d_ij, tw_j| {
535 let abs_d = d_ij.abs();
536 let sign_d = d_ij.signum();
537 let mut sum = F::zero();
538 for &tw in tw_j.iter() {
539 let u = tw * abs_d;
540 let f = F::one() + sqrt5 * u + div5_3 * u * u;
541 sum += tw * tw * abs_d * (F::one() + sqrt5 * u) / f;
542 }
543 *jr_ij = neg5_3 * sign_d * r_i * sum;
544 },
545 );
546 });
547 jr
548 }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::{DiffMatrix, NormalizedData};
    use approx::assert_abs_diff_eq;
    use ndarray::{arr1, array};
    use paste::paste;

    #[test]
    fn test_squared_exponential() {
        let xt = array![[4.5], [1.2], [2.0], [3.0], [4.0]];
        let dm = DiffMatrix::new(&xt);
        let res = SquaredExponentialCorr::default().rval_from_distances(
            &dm.d,
            &arr1(&[f64::sqrt(0.2)]),
            &array![[1.]],
        );
        let expected = array![
            [0.336552878364737],
            [0.5352614285189903],
            [0.7985162187593771],
            [0.9753099120283326],
            [0.9380049995307295],
            [0.7232502423798424],
            [0.4565760496233148],
            [0.9048374180359595],
            [0.6703200460356393],
            [0.9048374180359595]
        ];
        assert_abs_diff_eq!(res, expected, epsilon = 1e-6);
    }

    #[test]
    fn test_squared_exponential_2d() {
        let xt = array![[0., 1.], [2., 3.], [4., 5.]];
        let dm = DiffMatrix::new(&xt);
        let res = SquaredExponentialCorr::default().rval_from_distances(
            &dm.d,
            &arr1(&[f64::sqrt(2.), 2.]),
            &array![[1., 0.], [0., 1.]],
        );
        let expected = array![[6.14421235e-06], [1.42516408e-21], [6.14421235e-06]];
        assert_abs_diff_eq!(res, expected, epsilon = 1e-6);
    }

    #[test]
    fn test_matern32_2d() {
        let xt = array![[0., 1.], [2., 3.], [4., 5.]];
        let dm = DiffMatrix::new(&xt);
        let res = Matern32Corr::default().rval_from_distances(
            &dm.d,
            &arr1(&[1., 2.]),
            &array![[1., 0.], [0., 1.]],
        );
        let expected = array![[1.08539595e-03], [1.10776401e-07], [1.08539595e-03]];
        assert_abs_diff_eq!(res, expected, epsilon = 1e-6);
    }

    /// The serialization strings, `Display` output and parsing must agree.
    #[test]
    fn test_string_roundtrips() {
        let s: String = SquaredExponentialCorr::default().into();
        assert_eq!(s, SquaredExponentialCorr::default().to_string());
        assert_eq!(
            SquaredExponentialCorr::try_from(s),
            Ok(SquaredExponentialCorr::default())
        );

        let s: String = AbsoluteExponentialCorr::default().into();
        assert_eq!(s, AbsoluteExponentialCorr::default().to_string());
        assert_eq!(
            AbsoluteExponentialCorr::try_from(s),
            Ok(AbsoluteExponentialCorr::default())
        );

        let s: String = Matern32Corr::default().into();
        assert_eq!(s, Matern32Corr::default().to_string());
        assert_eq!(Matern32Corr::try_from(s), Ok(Matern32Corr::default()));

        let s: String = Matern52Corr::default().into();
        assert_eq!(s, Matern52Corr::default().to_string());
        assert_eq!(Matern52Corr::try_from(s), Ok(Matern52Corr::default()));

        assert!(Matern52Corr::try_from("Matern32".to_string()).is_err());
        assert!(SquaredExponentialCorr::try_from(String::new()).is_err());
    }

    macro_rules! test_correlation {
        ($corr:ident, $kpls:expr_2021) => {
            paste! {
                #[test]
                fn [<test_corr_ $corr:lower _kpls_ $kpls _derivatives>]() {
                    let x = array![3., 5.];
                    let xt = array![
                        [-9.375, -5.625],
                        [-5.625, -4.375],
                        [9.375, 1.875],
                        [8.125, 5.625],
                        [-4.375, -0.625],
                        [6.875, -3.125],
                        [4.375, 9.375],
                        [3.125, 4.375],
                        [5.625, -8.125],
                        [-8.125, 3.125],
                        [1.875, -6.875],
                        [-0.625, 8.125],
                        [-1.875, -1.875],
                        [0.625, 0.625],
                        [-6.875, -9.375],
                        [-3.125, 6.875]
                    ];
                    let xtrain = NormalizedData::new(&xt);
                    let xnorm = (x.to_owned() - &xtrain.mean) / &xtrain.std;
                    let (theta, weights) = if $kpls {
                        (array![0.31059002],
                         array![[-0.02701716],
                                [-0.99963497]])
                    } else {
                        (array![0.34599115925909146, 0.32083374253611624],
                         array![[1., 0.], [0., 1.]])
                    };

                    let corr = [< $corr Corr >]::default();
                    let jac = corr.jac(&xnorm, &xtrain.data, &theta, &weights) / &xtrain.std;

                    // The combined evaluation must agree with the separate ones.
                    let d = differences(&xnorm, &xtrain.data);
                    let r_alone = corr.rval_from_distances(&d, &theta, &weights);
                    let (r_both, jac_both) =
                        corr.rval_with_jac(&xnorm, &xtrain.data, &theta, &weights);
                    assert_abs_diff_eq!(r_both, r_alone, epsilon = 1e-12);
                    assert_abs_diff_eq!(jac_both / &xtrain.std, jac, epsilon = 1e-12);

                    // Central finite differences along each input component.
                    let xa: f64 = x[0];
                    let xb: f64 = x[1];
                    let e = 1e-5;
                    let x = array![
                        [xa, xb],
                        [xa + e, xb],
                        [xa - e, xb],
                        [xa, xb + e],
                        [xa, xb - e]
                    ];

                    let mut rxx = Array2::zeros((xtrain.data.nrows(), x.nrows()));
                    Zip::from(rxx.columns_mut())
                        .and(x.rows())
                        .for_each(|mut rxxi, xi| {
                            let xnorm = (xi.to_owned() - &xtrain.mean) / &xtrain.std;
                            let d = differences(&xnorm, &xtrain.data);
                            rxxi.assign(&(corr.rval_from_distances(&d, &theta, &weights).column(0)));
                        });
                    let fdiffa = (rxx.column(1).to_owned() - rxx.column(2)).mapv(|v| v / (2. * e));
                    assert_abs_diff_eq!(fdiffa, jac.column(0), epsilon = 1e-6);
                    let fdiffb = (rxx.column(3).to_owned() - rxx.column(4)).mapv(|v| v / (2. * e));
                    assert_abs_diff_eq!(fdiffb, jac.column(1), epsilon = 1e-6);
                }
            }
        };
    }

    test_correlation!(SquaredExponential, false);
    test_correlation!(AbsoluteExponential, false);
    test_correlation!(Matern32, false);
    test_correlation!(Matern52, false);
    test_correlation!(SquaredExponential, true);
    test_correlation!(AbsoluteExponential, true);
    test_correlation!(Matern32, true);
    test_correlation!(Matern52, true);

    #[test]
    fn test_matern52_2d() {
        let xt = array![[0., 1.], [2., 3.], [4., 5.]];
        let dm = DiffMatrix::new(&xt);
        let res = Matern52Corr::default().rval_from_distances(
            &dm.d,
            &arr1(&[1., 2.]),
            &array![[1., 0.], [0., 1.]],
        );
        let expected = array![[6.62391590e-04], [1.02117882e-08], [6.62391590e-04]];
        assert_abs_diff_eq!(res, expected, epsilon = 1e-6);
    }
}