Skip to main content

scirs2_optimize/surrogate/
kriging.rs

1//! Kriging (Gaussian Process) Surrogate Model
2//!
3//! Kriging is a spatial interpolation method that provides not only predictions
4//! but also prediction uncertainty (variance). It is the foundation of
5//! Bayesian optimization and is well-suited for expensive black-box optimization.
6//!
7//! ## Features
8//!
9//! - Multiple correlation functions (Gaussian, Matern, exponential)
10//! - Nugget parameter for handling noisy evaluations
11//! - Maximum Likelihood Estimation (MLE) of hyperparameters
12//! - Analytical prediction variance
13//!
14//! ## References
15//!
16//! - Sacks, J., Welch, W.J., Mitchell, T.J., Wynn, H.P. (1989).
17//!   Design and Analysis of Computer Experiments.
18//! - Rasmussen, C.E. & Williams, C.K.I. (2006).
19//!   Gaussian Processes for Machine Learning.
20
21use super::{solve_general, SurrogateModel};
22use crate::error::{OptimizeError, OptimizeResult};
23use scirs2_core::ndarray::{Array1, Array2};
24use scirs2_core::random::rngs::StdRng;
25use scirs2_core::random::{Rng, SeedableRng};
26use scirs2_core::RngExt;
27
/// Correlation function for Kriging.
///
/// Every variant is parameterised by per-dimension weights `theta_k`; larger weights make
/// the correlation decay faster along that dimension.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum CorrelationFunction {
    /// Squared exponential (Gaussian): `exp(-sum_k theta_k * |x_k - y_k|^2)`.
    ///
    /// This is the default correlation function.
    #[default]
    SquaredExponential,
    /// Matern 3/2: `(1 + sqrt(3)*r) * exp(-sqrt(3)*r)`,
    /// with `r = sqrt(sum_k theta_k * |x_k - y_k|^2)`.
    Matern32,
    /// Matern 5/2: `(1 + sqrt(5)*r + 5/3*r^2) * exp(-sqrt(5)*r)`,
    /// with `r = sqrt(sum_k theta_k * |x_k - y_k|^2)`.
    Matern52,
    /// Exponential: `exp(-sum_k theta_k * |x_k - y_k|)`.
    ///
    /// In one dimension this is the Matern 1/2 kernel. In several dimensions it is the
    /// product of one-dimensional exponential kernels (a weighted L1 distance), not the
    /// isotropic Matern 1/2 kernel.
    Exponential,
    /// Power exponential: `exp(-sum_k theta_k * |x_k - y_k|^p)`.
    ///
    /// The kernel is only guaranteed to be positive definite for `0 < p <= 2`.
    PowerExponential {
        /// Smoothness exponent (1 = exponential, 2 = Gaussian)
        p: f64,
    },
}
51
52/// Options for Kriging surrogate
53#[derive(Debug, Clone)]
54pub struct KrigingOptions {
55    /// Correlation function to use
56    pub correlation: CorrelationFunction,
57    /// Nugget parameter (regularization / noise variance)
58    /// If None, will be estimated from data
59    pub nugget: Option<f64>,
60    /// Whether to estimate hyperparameters via MLE
61    pub optimize_hyperparams: bool,
62    /// Number of random restarts for hyperparameter optimization
63    pub n_restarts: usize,
64    /// Random seed for reproducibility
65    pub seed: Option<u64>,
66    /// Initial length-scale parameters (theta). If None, uses heuristic.
67    pub initial_theta: Option<Vec<f64>>,
68    /// Lower bound for theta
69    pub theta_lower: f64,
70    /// Upper bound for theta
71    pub theta_upper: f64,
72}
73
74impl Default for KrigingOptions {
75    fn default() -> Self {
76        Self {
77            correlation: CorrelationFunction::default(),
78            nugget: Some(1e-6),
79            optimize_hyperparams: true,
80            n_restarts: 5,
81            seed: None,
82            initial_theta: None,
83            theta_lower: 1e-3,
84            theta_upper: 1e3,
85        }
86    }
87}
88
89/// Kriging (Gaussian Process) Surrogate Model
90pub struct KrigingSurrogate {
91    options: KrigingOptions,
92    /// Training points (normalized)
93    x_train: Option<Array2<f64>>,
94    /// Training values (normalized)
95    y_train: Option<Array1<f64>>,
96    /// Estimated length-scale parameters
97    theta: Option<Vec<f64>>,
98    /// Estimated nugget
99    nugget: f64,
100    /// Kriging weights (alpha = R^{-1} * (y - mu))
101    alpha: Option<Array1<f64>>,
102    /// Estimated mean (trend)
103    mu: f64,
104    /// Estimated process variance
105    sigma_sq: f64,
106    /// Correlation matrix (R)
107    corr_matrix: Option<Array2<f64>>,
108    /// Cholesky factor of R (lower triangular)
109    chol_factor: Option<Array2<f64>>,
110    /// Normalization parameters
111    x_min: Option<Array1<f64>>,
112    x_range: Option<Array1<f64>>,
113    y_mean: f64,
114    y_std: f64,
115}
116
117impl KrigingSurrogate {
118    /// Create a new Kriging surrogate
119    pub fn new(options: KrigingOptions) -> Self {
120        let nugget = options.nugget.unwrap_or(1e-6);
121        Self {
122            options,
123            x_train: None,
124            y_train: None,
125            theta: None,
126            nugget,
127            alpha: None,
128            mu: 0.0,
129            sigma_sq: 1.0,
130            corr_matrix: None,
131            chol_factor: None,
132            x_min: None,
133            x_range: None,
134            y_mean: 0.0,
135            y_std: 1.0,
136        }
137    }
138
139    /// Compute correlation between two points given theta
140    fn correlation(&self, x1: &[f64], x2: &[f64], theta: &[f64]) -> f64 {
141        let d = x1.len();
142        match self.options.correlation {
143            CorrelationFunction::SquaredExponential => {
144                let mut sum = 0.0;
145                for k in 0..d {
146                    let diff = x1[k] - x2[k];
147                    sum += theta[k.min(theta.len() - 1)] * diff * diff;
148                }
149                (-sum).exp()
150            }
151            CorrelationFunction::Matern32 => {
152                let mut weighted_sq_sum = 0.0;
153                for k in 0..d {
154                    let diff = x1[k] - x2[k];
155                    weighted_sq_sum += theta[k.min(theta.len() - 1)] * diff * diff;
156                }
157                let r = (3.0 * weighted_sq_sum).sqrt();
158                (1.0 + r) * (-r).exp()
159            }
160            CorrelationFunction::Matern52 => {
161                let mut weighted_sq_sum = 0.0;
162                for k in 0..d {
163                    let diff = x1[k] - x2[k];
164                    weighted_sq_sum += theta[k.min(theta.len() - 1)] * diff * diff;
165                }
166                let r = (5.0 * weighted_sq_sum).sqrt();
167                (1.0 + r + r * r / 3.0) * (-r).exp()
168            }
169            CorrelationFunction::Exponential => {
170                let mut sum = 0.0;
171                for k in 0..d {
172                    let diff = (x1[k] - x2[k]).abs();
173                    sum += theta[k.min(theta.len() - 1)] * diff;
174                }
175                (-sum).exp()
176            }
177            CorrelationFunction::PowerExponential { p } => {
178                let mut sum = 0.0;
179                for k in 0..d {
180                    let diff = (x1[k] - x2[k]).abs();
181                    sum += theta[k.min(theta.len() - 1)] * diff.powf(p);
182                }
183                (-sum).exp()
184            }
185        }
186    }
187
188    /// Compute the correlation matrix for given points
189    fn compute_correlation_matrix(
190        &self,
191        x: &Array2<f64>,
192        theta: &[f64],
193        nugget: f64,
194    ) -> Array2<f64> {
195        let n = x.nrows();
196        let mut r = Array2::zeros((n, n));
197        for i in 0..n {
198            r[[i, i]] = 1.0 + nugget;
199            let x_i: Vec<f64> = (0..x.ncols()).map(|k| x[[i, k]]).collect();
200            for j in (i + 1)..n {
201                let x_j: Vec<f64> = (0..x.ncols()).map(|k| x[[j, k]]).collect();
202                let c = self.correlation(&x_i, &x_j, theta);
203                r[[i, j]] = c;
204                r[[j, i]] = c;
205            }
206        }
207        r
208    }
209
210    /// Compute correlation vector between a point and training data
211    fn compute_correlation_vector(
212        &self,
213        x: &[f64],
214        x_train: &Array2<f64>,
215        theta: &[f64],
216    ) -> Array1<f64> {
217        let n = x_train.nrows();
218        let mut r = Array1::zeros(n);
219        for i in 0..n {
220            let x_i: Vec<f64> = (0..x_train.ncols()).map(|k| x_train[[i, k]]).collect();
221            r[i] = self.correlation(x, &x_i, theta);
222        }
223        r
224    }
225
226    /// Cholesky decomposition
227    fn cholesky(&self, a: &Array2<f64>) -> OptimizeResult<Array2<f64>> {
228        let n = a.nrows();
229        let mut l = Array2::zeros((n, n));
230        for j in 0..n {
231            let mut sum = 0.0;
232            for k in 0..j {
233                sum += l[[j, k]] * l[[j, k]];
234            }
235            let diag = a[[j, j]] - sum;
236            if diag <= 0.0 {
237                return Err(OptimizeError::ComputationError(
238                    "Correlation matrix is not positive definite".to_string(),
239                ));
240            }
241            l[[j, j]] = diag.sqrt();
242            for i in (j + 1)..n {
243                let mut sum = 0.0;
244                for k in 0..j {
245                    sum += l[[i, k]] * l[[j, k]];
246                }
247                l[[i, j]] = (a[[i, j]] - sum) / l[[j, j]];
248            }
249        }
250        Ok(l)
251    }
252
253    /// Solve L * x = b (forward substitution)
254    fn solve_lower(&self, l: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
255        let n = b.len();
256        let mut x = Array1::zeros(n);
257        for i in 0..n {
258            let mut sum = 0.0;
259            for j in 0..i {
260                sum += l[[i, j]] * x[j];
261            }
262            x[i] = (b[i] - sum) / l[[i, i]];
263        }
264        x
265    }
266
267    /// Solve L^T * x = b (back substitution)
268    fn solve_upper(&self, l: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
269        let n = b.len();
270        let mut x = Array1::zeros(n);
271        for i in (0..n).rev() {
272            let mut sum = 0.0;
273            for j in (i + 1)..n {
274                sum += l[[j, i]] * x[j];
275            }
276            x[i] = (b[i] - sum) / l[[i, i]];
277        }
278        x
279    }
280
281    /// Compute concentrated log-likelihood for given theta
282    fn log_likelihood(
283        &self,
284        x_train: &Array2<f64>,
285        y_train: &Array1<f64>,
286        theta: &[f64],
287        nugget: f64,
288    ) -> f64 {
289        let n = x_train.nrows();
290        let r = self.compute_correlation_matrix(x_train, theta, nugget);
291
292        let chol = match self.cholesky(&r) {
293            Ok(l) => l,
294            Err(_) => return f64::NEG_INFINITY,
295        };
296
297        // Compute log determinant
298        let log_det: f64 = (0..n).map(|i| chol[[i, i]].ln()).sum::<f64>() * 2.0;
299
300        // Solve R * ones = r_ones  to get mu
301        let ones = Array1::ones(n);
302        let z = self.solve_lower(&chol, &ones);
303        let r_inv_ones = self.solve_upper(&chol, &z);
304        let ones_r_inv_ones: f64 = ones.dot(&r_inv_ones);
305
306        if ones_r_inv_ones.abs() < 1e-30 {
307            return f64::NEG_INFINITY;
308        }
309
310        // Solve R * y_solve = y
311        let z_y = self.solve_lower(&chol, y_train);
312        let r_inv_y = self.solve_upper(&chol, &z_y);
313
314        let mu_hat = ones.dot(&r_inv_y) / ones_r_inv_ones;
315
316        // Compute sigma^2
317        let residual: Array1<f64> = y_train - mu_hat;
318        let z_res = self.solve_lower(&chol, &residual);
319        let r_inv_res = self.solve_upper(&chol, &z_res);
320        let sigma_sq = residual.dot(&r_inv_res) / n as f64;
321
322        if sigma_sq <= 0.0 {
323            return f64::NEG_INFINITY;
324        }
325
326        // Concentrated log-likelihood
327        -0.5 * (n as f64 * sigma_sq.ln() + log_det)
328    }
329
330    /// Optimize hyperparameters using random search with local refinement
331    fn optimize_hyperparameters(
332        &self,
333        x_train: &Array2<f64>,
334        y_train: &Array1<f64>,
335    ) -> (Vec<f64>, f64) {
336        let d = x_train.ncols();
337        let seed = self
338            .options
339            .seed
340            .unwrap_or_else(|| scirs2_core::random::rng().random_range(0..u64::MAX));
341        let mut rng = StdRng::seed_from_u64(seed);
342
343        let theta_lo = self.options.theta_lower;
344        let theta_hi = self.options.theta_upper;
345        let log_lo = theta_lo.ln();
346        let log_hi = theta_hi.ln();
347
348        let nugget = self.nugget;
349
350        // Initial theta
351        let mut best_theta: Vec<f64> = self
352            .options
353            .initial_theta
354            .clone()
355            .unwrap_or_else(|| vec![1.0; d]);
356
357        let mut best_ll = self.log_likelihood(x_train, y_train, &best_theta, nugget);
358
359        // Random restarts
360        for _ in 0..self.options.n_restarts {
361            let theta: Vec<f64> = (0..d)
362                .map(|_| rng.random_range(log_lo..log_hi).exp())
363                .collect();
364
365            let ll = self.log_likelihood(x_train, y_train, &theta, nugget);
366            if ll > best_ll {
367                best_ll = ll;
368                best_theta = theta;
369            }
370        }
371
372        // Local refinement via coordinate-wise line search
373        for _ in 0..3 {
374            for k in 0..d {
375                let mut best_tk = best_theta[k];
376                let mut best_ll_k = best_ll;
377
378                for &factor in &[0.1, 0.3, 0.5, 0.7, 1.5, 2.0, 3.0, 10.0] {
379                    let mut trial = best_theta.clone();
380                    trial[k] = (best_theta[k] * factor).clamp(theta_lo, theta_hi);
381                    let ll = self.log_likelihood(x_train, y_train, &trial, nugget);
382                    if ll > best_ll_k {
383                        best_ll_k = ll;
384                        best_tk = trial[k];
385                    }
386                }
387
388                best_theta[k] = best_tk;
389                best_ll = best_ll_k;
390            }
391        }
392
393        (best_theta, nugget)
394    }
395
396    /// Normalize X to [0, 1]
397    fn normalize_x(&self, x: &Array2<f64>) -> Array2<f64> {
398        if let (Some(ref x_min), Some(ref x_range)) = (&self.x_min, &self.x_range) {
399            let mut normalized = x.clone();
400            for i in 0..x.nrows() {
401                for j in 0..x.ncols() {
402                    let r = if x_range[j] > 1e-30 { x_range[j] } else { 1.0 };
403                    normalized[[i, j]] = (x[[i, j]] - x_min[j]) / r;
404                }
405            }
406            normalized
407        } else {
408            x.clone()
409        }
410    }
411
412    /// Normalize a single x point
413    fn normalize_x_point(&self, x: &Array1<f64>) -> Vec<f64> {
414        if let (Some(ref x_min), Some(ref x_range)) = (&self.x_min, &self.x_range) {
415            x.iter()
416                .enumerate()
417                .map(|(j, &xj)| {
418                    let r = if x_range[j] > 1e-30 { x_range[j] } else { 1.0 };
419                    (xj - x_min[j]) / r
420                })
421                .collect()
422        } else {
423            x.to_vec()
424        }
425    }
426}
427
428impl SurrogateModel for KrigingSurrogate {
429    fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> OptimizeResult<()> {
430        let n = x.nrows();
431        let d = x.ncols();
432
433        if n < 2 {
434            return Err(OptimizeError::InvalidInput(
435                "Need at least 2 data points for Kriging".to_string(),
436            ));
437        }
438
439        // Compute normalization
440        let mut x_min = Array1::zeros(d);
441        let mut x_max = Array1::zeros(d);
442        for j in 0..d {
443            let mut lo = f64::INFINITY;
444            let mut hi = f64::NEG_INFINITY;
445            for i in 0..n {
446                if x[[i, j]] < lo {
447                    lo = x[[i, j]];
448                }
449                if x[[i, j]] > hi {
450                    hi = x[[i, j]];
451                }
452            }
453            x_min[j] = lo;
454            x_max[j] = hi;
455        }
456        let x_range = &x_max - &x_min;
457        self.x_min = Some(x_min);
458        self.x_range = Some(x_range);
459
460        let y_sum: f64 = y.iter().sum();
461        self.y_mean = y_sum / n as f64;
462        let y_var: f64 = y.iter().map(|yi| (yi - self.y_mean).powi(2)).sum::<f64>() / n as f64;
463        self.y_std = y_var.sqrt().max(1e-30);
464
465        // Normalize
466        let x_norm = self.normalize_x(x);
467        let y_norm: Array1<f64> = y.mapv(|yi| (yi - self.y_mean) / self.y_std);
468
469        // Optimize hyperparameters
470        let (theta, nugget) = if self.options.optimize_hyperparams {
471            self.optimize_hyperparameters(&x_norm, &y_norm)
472        } else {
473            let theta = self
474                .options
475                .initial_theta
476                .clone()
477                .unwrap_or_else(|| vec![1.0; d]);
478            (theta, self.nugget)
479        };
480        self.theta = Some(theta.clone());
481        self.nugget = nugget;
482
483        // Build correlation matrix
484        let r = self.compute_correlation_matrix(&x_norm, &theta, nugget);
485        let chol = self.cholesky(&r)?;
486
487        // Estimate mu
488        let ones = Array1::ones(n);
489        let z = self.solve_lower(&chol, &ones);
490        let r_inv_ones = self.solve_upper(&chol, &z);
491        let ones_r_inv_ones = ones.dot(&r_inv_ones);
492
493        let z_y = self.solve_lower(&chol, &y_norm);
494        let r_inv_y = self.solve_upper(&chol, &z_y);
495
496        self.mu = if ones_r_inv_ones.abs() > 1e-30 {
497            ones.dot(&r_inv_y) / ones_r_inv_ones
498        } else {
499            y_norm.mean().unwrap_or(0.0)
500        };
501
502        // Compute alpha = R^{-1} * (y - mu)
503        let residual: Array1<f64> = &y_norm - self.mu;
504        let z_res = self.solve_lower(&chol, &residual);
505        let alpha = self.solve_upper(&chol, &z_res);
506
507        // Estimate sigma^2
508        self.sigma_sq = (residual.dot(&alpha) / n as f64).max(1e-20);
509
510        self.alpha = Some(alpha);
511        self.corr_matrix = Some(r);
512        self.chol_factor = Some(chol);
513        self.x_train = Some(x_norm);
514        self.y_train = Some(y_norm);
515
516        Ok(())
517    }
518
519    fn predict(&self, x: &Array1<f64>) -> OptimizeResult<f64> {
520        let x_train = self
521            .x_train
522            .as_ref()
523            .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
524        let alpha = self
525            .alpha
526            .as_ref()
527            .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
528        let theta = self
529            .theta
530            .as_ref()
531            .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
532
533        let x_norm = self.normalize_x_point(x);
534        let r = self.compute_correlation_vector(&x_norm, x_train, theta);
535
536        let prediction_norm = self.mu + r.dot(alpha);
537
538        // Denormalize
539        Ok(prediction_norm * self.y_std + self.y_mean)
540    }
541
542    fn predict_with_uncertainty(&self, x: &Array1<f64>) -> OptimizeResult<(f64, f64)> {
543        let x_train = self
544            .x_train
545            .as_ref()
546            .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
547        let alpha = self
548            .alpha
549            .as_ref()
550            .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
551        let theta = self
552            .theta
553            .as_ref()
554            .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
555        let chol = self
556            .chol_factor
557            .as_ref()
558            .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
559
560        let n = x_train.nrows();
561        let x_norm = self.normalize_x_point(x);
562        let r = self.compute_correlation_vector(&x_norm, x_train, theta);
563
564        // Mean prediction
565        let prediction_norm = self.mu + r.dot(alpha);
566        let mean = prediction_norm * self.y_std + self.y_mean;
567
568        // Prediction variance via Kriging equations
569        // s^2(x) = sigma^2 * (1 - r^T R^{-1} r + (1 - 1^T R^{-1} r)^2 / (1^T R^{-1} 1))
570        let z = self.solve_lower(chol, &r);
571        let rt_r_inv_r = z.dot(&z);
572
573        let ones = Array1::ones(n);
574        let z_ones = self.solve_lower(chol, &ones);
575        let ones_r_inv_r: f64 = z_ones.dot(&z);
576        let ones_r_inv_ones: f64 = z_ones.dot(&z_ones);
577
578        let numerator = (1.0 - ones_r_inv_r).powi(2);
579        let denominator = ones_r_inv_ones.max(1e-30);
580
581        let mse_norm = self.sigma_sq * (1.0 - rt_r_inv_r + numerator / denominator).max(0.0);
582        let std = (mse_norm * self.y_std * self.y_std).sqrt().max(1e-10);
583
584        Ok((mean, std))
585    }
586
587    fn n_samples(&self) -> usize {
588        self.x_train.as_ref().map_or(0, |x| x.nrows())
589    }
590
591    fn n_features(&self) -> usize {
592        self.x_train.as_ref().map_or(0, |x| x.ncols())
593    }
594
595    fn update(&mut self, x: &Array1<f64>, y: f64) -> OptimizeResult<()> {
596        // Refit with new data point
597        let (new_x, new_y) =
598            if let (Some(ref x_train), Some(ref y_train)) = (&self.x_train, &self.y_train) {
599                let d = x_train.ncols();
600                let n = x_train.nrows();
601
602                // Denormalize
603                let mut x_denorm = Array2::zeros((n, d));
604                for i in 0..n {
605                    for j in 0..d {
606                        let r = self.x_range.as_ref().map_or(1.0, |xr| {
607                            if xr[j] > 1e-30 {
608                                xr[j]
609                            } else {
610                                1.0
611                            }
612                        });
613                        let m = self.x_min.as_ref().map_or(0.0, |xm| xm[j]);
614                        x_denorm[[i, j]] = x_train[[i, j]] * r + m;
615                    }
616                }
617                let y_denorm: Array1<f64> = y_train.mapv(|yi| yi * self.y_std + self.y_mean);
618
619                let mut new_x = Array2::zeros((n + 1, d));
620                for i in 0..n {
621                    for j in 0..d {
622                        new_x[[i, j]] = x_denorm[[i, j]];
623                    }
624                }
625                for j in 0..d {
626                    new_x[[n, j]] = x[j];
627                }
628
629                let mut new_y = Array1::zeros(n + 1);
630                for i in 0..n {
631                    new_y[i] = y_denorm[i];
632                }
633                new_y[n] = y;
634
635                (new_x, new_y)
636            } else {
637                let d = x.len();
638                let mut new_x = Array2::zeros((1, d));
639                for j in 0..d {
640                    new_x[[0, j]] = x[j];
641                }
642                (new_x, Array1::from_vec(vec![y]))
643            };
644
645        self.fit(&new_x, &new_y)
646    }
647}
648
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a one-dimensional training set from matching input/target slices.
    fn one_dim_data(xs: &[f64], ys: &[f64]) -> (Array2<f64>, Array1<f64>) {
        let x = Array2::from_shape_vec((xs.len(), 1), xs.to_vec()).expect("Array creation failed");
        (x, Array1::from_vec(ys.to_vec()))
    }

    /// Options that pin `theta` and the nugget instead of optimising them.
    fn pinned(correlation: CorrelationFunction, nugget: f64, theta: f64) -> KrigingOptions {
        KrigingOptions {
            correlation,
            nugget: Some(nugget),
            optimize_hyperparams: false,
            initial_theta: Some(vec![theta]),
            ..Default::default()
        }
    }

    #[test]
    fn test_kriging_basic_interpolation() {
        let (x_train, y_train) =
            one_dim_data(&[0.0, 0.25, 0.5, 0.75, 1.0], &[0.0, 0.25, 1.0, 0.75, 0.0]);
        let mut kriging =
            KrigingSurrogate::new(pinned(CorrelationFunction::SquaredExponential, 1e-4, 10.0));

        let result = kriging.fit(&x_train, &y_train);
        assert!(result.is_ok(), "Kriging fit failed: {:?}", result.err());

        // With a tiny nugget the model should nearly reproduce its training targets.
        let inputs = x_train.column(0);
        for (i, (&xi, &target)) in inputs.iter().zip(y_train.iter()).enumerate() {
            let pred = kriging
                .predict(&Array1::from_vec(vec![xi]))
                .expect("Prediction failed");
            assert!(
                (pred - target).abs() < 0.2,
                "Kriging interpolation error at {}: pred={}, actual={}",
                i,
                pred,
                target
            );
        }
    }

    #[test]
    fn test_kriging_uncertainty() {
        let (x_train, y_train) = one_dim_data(&[0.0, 0.33, 0.66, 1.0], &[0.0, 1.0, 1.0, 0.0]);
        let mut kriging =
            KrigingSurrogate::new(pinned(CorrelationFunction::SquaredExponential, 1e-4, 5.0));
        kriging.fit(&x_train, &y_train).expect("Fit failed");

        let std_at = |v: f64| {
            kriging
                .predict_with_uncertainty(&Array1::from_vec(vec![v]))
                .expect("Prediction failed")
                .1
        };

        // Uncertainty at a training point should be lower than far outside the data.
        let unc_near = std_at(0.33);
        let unc_far = std_at(2.0);
        assert!(
            unc_far > unc_near,
            "Far uncertainty ({}) should exceed near uncertainty ({})",
            unc_far,
            unc_near
        );
    }

    #[test]
    fn test_kriging_2d() {
        let corners = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let x_train = Array2::from_shape_vec((4, 2), corners).expect("Array creation failed");
        let y_train = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);

        let options = KrigingOptions {
            nugget: Some(1e-4),
            n_restarts: 2,
            ..Default::default()
        };
        let mut kriging = KrigingSurrogate::new(options);
        assert!(kriging.fit(&x_train, &y_train).is_ok());

        let pred = kriging.predict(&Array1::from_vec(vec![0.5, 0.5]));
        assert!(pred.is_ok());
        let val = pred.expect("2D prediction failed");
        assert!(val > -1.0 && val < 3.0, "Kriging 2D prediction: {}", val);
    }

    #[test]
    fn test_kriging_matern32() {
        let (x_train, y_train) = one_dim_data(&[0.0, 0.5, 1.0], &[1.0, 2.0, 1.0]);
        let mut kriging = KrigingSurrogate::new(pinned(CorrelationFunction::Matern32, 1e-4, 5.0));
        assert!(kriging.fit(&x_train, &y_train).is_ok());
        assert!(kriging.predict(&Array1::from_vec(vec![0.25])).is_ok());
    }

    #[test]
    fn test_kriging_matern52() {
        let (x_train, y_train) = one_dim_data(&[0.0, 0.5, 1.0], &[0.0, 1.0, 0.0]);
        let mut kriging = KrigingSurrogate::new(pinned(CorrelationFunction::Matern52, 1e-4, 5.0));
        assert!(kriging.fit(&x_train, &y_train).is_ok());
    }

    #[test]
    fn test_kriging_exponential() {
        let (x_train, y_train) = one_dim_data(&[0.0, 0.5, 1.0], &[0.0, 1.0, 0.0]);
        let mut kriging =
            KrigingSurrogate::new(pinned(CorrelationFunction::Exponential, 1e-3, 5.0));
        assert!(kriging.fit(&x_train, &y_train).is_ok());
    }

    #[test]
    fn test_kriging_update() {
        let (x_train, y_train) = one_dim_data(&[0.0, 0.5, 1.0], &[0.0, 1.0, 0.0]);
        let mut kriging =
            KrigingSurrogate::new(pinned(CorrelationFunction::SquaredExponential, 1e-4, 5.0));
        kriging.fit(&x_train, &y_train).expect("Fit failed");
        assert_eq!(kriging.n_samples(), 3);

        kriging
            .update(&Array1::from_vec(vec![0.25]), 0.5)
            .expect("Update failed");
        assert_eq!(kriging.n_samples(), 4);
    }

    #[test]
    fn test_kriging_power_exponential() {
        let (x_train, y_train) = one_dim_data(&[0.0, 0.5, 1.0], &[1.0, 0.5, 1.0]);
        let mut kriging = KrigingSurrogate::new(pinned(
            CorrelationFunction::PowerExponential { p: 1.5 },
            1e-3,
            5.0,
        ));
        assert!(kriging.fit(&x_train, &y_train).is_ok());
    }
}