Skip to main content

scirs2_optimize/surrogate/
ensemble.rs

//! Ensemble of Surrogate Models
//!
//! This module provides an ensemble surrogate that combines multiple surrogate
//! models to improve prediction accuracy and robustness. Each member model is
//! assigned a weight derived from the configured model selection criterion.
//!
//! ## Features
//!
//! - Combines RBF (cubic, Gaussian, multiquadric, thin-plate spline) and Kriging
//!   (squared exponential, Matern 5/2) surrogates
//! - Automatic model weighting from per-model error estimates
//! - Several model selection criteria: LOOCV, K-fold, AIC, equal weighting, and
//!   best single model
//!
//! ## References
//!
//! - Viana, F.A.C., Haftka, R.T., Watson, L.T. (2009).
//!   Efficient Global Optimization Algorithm Assisted by Multiple Surrogate Techniques.
//! - Goel, T., Haftka, R.T., Shyy, W., Queipo, N.V. (2007).
//!   Ensemble of Surrogates.
20
21use super::{
22    kriging::{CorrelationFunction, KrigingOptions, KrigingSurrogate},
23    rbf_surrogate::{RbfKernel, RbfOptions, RbfSurrogate},
24    SurrogateModel,
25};
26use crate::error::{OptimizeError, OptimizeResult};
27use scirs2_core::ndarray::{Array1, Array2};
28
/// Strategy that turns per-model error estimates into ensemble weights.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ModelSelectionCriterion {
    /// Leave-one-out cross-validation (LOOCV) error estimate.
    Loocv,
    /// K-fold cross-validation error estimate.
    KFold {
        /// How many folds the data is split into.
        k: usize,
    },
    /// Approximate Akaike information criterion computed from training residuals.
    Aic,
    /// Every model contributes the same weight, regardless of its error.
    Equal,
    /// Winner takes all: only the model with the lowest score is used.
    BestSingle,
}

impl Default for ModelSelectionCriterion {
    /// LOOCV is the default criterion.
    fn default() -> Self {
        Self::Loocv
    }
}
52
53/// Options for ensemble surrogate
54#[derive(Debug, Clone)]
55pub struct EnsembleOptions {
56    /// Model selection criterion
57    pub criterion: ModelSelectionCriterion,
58    /// Whether to include RBF with cubic kernel
59    pub include_rbf_cubic: bool,
60    /// Whether to include RBF with Gaussian kernel
61    pub include_rbf_gaussian: bool,
62    /// Whether to include RBF with multiquadric kernel
63    pub include_rbf_multiquadric: bool,
64    /// Whether to include RBF with thin-plate spline
65    pub include_rbf_tps: bool,
66    /// Whether to include Kriging with squared exponential
67    pub include_kriging_se: bool,
68    /// Whether to include Kriging with Matern 5/2
69    pub include_kriging_matern52: bool,
70    /// Minimum weight for a model to be included (pruning threshold)
71    pub min_weight: f64,
72    /// Random seed for cross-validation
73    pub seed: Option<u64>,
74}
75
76impl Default for EnsembleOptions {
77    fn default() -> Self {
78        Self {
79            criterion: ModelSelectionCriterion::default(),
80            include_rbf_cubic: true,
81            include_rbf_gaussian: true,
82            include_rbf_multiquadric: false,
83            include_rbf_tps: true,
84            include_kriging_se: true,
85            include_kriging_matern52: true,
86            min_weight: 0.01,
87            seed: None,
88        }
89    }
90}
91
92/// A member of the ensemble
93struct EnsembleMember {
94    /// The surrogate model
95    model: Box<dyn SurrogateModel>,
96    /// Name/label for the model
97    name: String,
98    /// Weight in the ensemble
99    weight: f64,
100}
101
102/// Ensemble Surrogate Model
103pub struct EnsembleSurrogate {
104    options: EnsembleOptions,
105    /// Ensemble members
106    members: Vec<EnsembleMember>,
107    /// Raw training data (kept for re-fitting)
108    x_train_raw: Option<Array2<f64>>,
109    y_train_raw: Option<Array1<f64>>,
110}
111
112impl EnsembleSurrogate {
113    /// Create a new ensemble surrogate
114    pub fn new(options: EnsembleOptions) -> Self {
115        Self {
116            options,
117            members: Vec::new(),
118            x_train_raw: None,
119            y_train_raw: None,
120        }
121    }
122
123    /// Create the ensemble members based on options
124    fn create_members(&self) -> Vec<(Box<dyn SurrogateModel>, String)> {
125        let mut members: Vec<(Box<dyn SurrogateModel>, String)> = Vec::new();
126
127        if self.options.include_rbf_cubic {
128            members.push((
129                Box::new(RbfSurrogate::new(RbfOptions {
130                    kernel: RbfKernel::Polyharmonic(3),
131                    regularization: 1e-8,
132                    normalize: true,
133                })),
134                "RBF-Cubic".to_string(),
135            ));
136        }
137
138        if self.options.include_rbf_gaussian {
139            members.push((
140                Box::new(RbfSurrogate::new(RbfOptions {
141                    kernel: RbfKernel::Gaussian { sigma: 1.0 },
142                    regularization: 1e-6,
143                    normalize: true,
144                })),
145                "RBF-Gaussian".to_string(),
146            ));
147        }
148
149        if self.options.include_rbf_multiquadric {
150            members.push((
151                Box::new(RbfSurrogate::new(RbfOptions {
152                    kernel: RbfKernel::Multiquadric { shape_param: 1.0 },
153                    regularization: 1e-8,
154                    normalize: true,
155                })),
156                "RBF-MQ".to_string(),
157            ));
158        }
159
160        if self.options.include_rbf_tps {
161            members.push((
162                Box::new(RbfSurrogate::new(RbfOptions {
163                    kernel: RbfKernel::ThinPlateSpline,
164                    regularization: 1e-8,
165                    normalize: true,
166                })),
167                "RBF-TPS".to_string(),
168            ));
169        }
170
171        if self.options.include_kriging_se {
172            members.push((
173                Box::new(KrigingSurrogate::new(KrigingOptions {
174                    correlation: CorrelationFunction::SquaredExponential,
175                    nugget: Some(1e-4),
176                    n_restarts: 3,
177                    seed: self.options.seed,
178                    ..Default::default()
179                })),
180                "Kriging-SE".to_string(),
181            ));
182        }
183
184        if self.options.include_kriging_matern52 {
185            members.push((
186                Box::new(KrigingSurrogate::new(KrigingOptions {
187                    correlation: CorrelationFunction::Matern52,
188                    nugget: Some(1e-4),
189                    n_restarts: 3,
190                    seed: self.options.seed,
191                    ..Default::default()
192                })),
193                "Kriging-Matern52".to_string(),
194            ));
195        }
196
197        members
198    }
199
200    /// Compute LOOCV error for a model
201    fn loocv_error(
202        &self,
203        model_factory: &dyn Fn() -> Box<dyn SurrogateModel>,
204        x: &Array2<f64>,
205        y: &Array1<f64>,
206    ) -> f64 {
207        let n = x.nrows();
208        let d = x.ncols();
209
210        if n < 3 {
211            return f64::INFINITY;
212        }
213
214        let mut total_sq_error = 0.0;
215        let mut valid_count = 0;
216
217        for leave_out in 0..n {
218            // Build training set without leave_out
219            let mut x_train = Array2::zeros((n - 1, d));
220            let mut y_train = Array1::zeros(n - 1);
221            let mut idx = 0;
222            for i in 0..n {
223                if i != leave_out {
224                    for j in 0..d {
225                        x_train[[idx, j]] = x[[i, j]];
226                    }
227                    y_train[idx] = y[i];
228                    idx += 1;
229                }
230            }
231
232            let mut model = model_factory();
233            if model.fit(&x_train, &y_train).is_ok() {
234                let x_test = x.row(leave_out).to_owned();
235                if let Ok(pred) = model.predict(&x_test) {
236                    let error = pred - y[leave_out];
237                    total_sq_error += error * error;
238                    valid_count += 1;
239                }
240            }
241        }
242
243        if valid_count > 0 {
244            total_sq_error / valid_count as f64
245        } else {
246            f64::INFINITY
247        }
248    }
249
250    /// Compute weights based on cross-validation errors
251    fn compute_weights(&self, cv_errors: &[f64]) -> Vec<f64> {
252        let n = cv_errors.len();
253        if n == 0 {
254            return Vec::new();
255        }
256
257        match self.options.criterion {
258            ModelSelectionCriterion::Equal => {
259                vec![1.0 / n as f64; n]
260            }
261            ModelSelectionCriterion::BestSingle => {
262                let mut weights = vec![0.0; n];
263                let mut best_idx = 0;
264                let mut best_err = f64::INFINITY;
265                for (i, &err) in cv_errors.iter().enumerate() {
266                    if err < best_err {
267                        best_err = err;
268                        best_idx = i;
269                    }
270                }
271                weights[best_idx] = 1.0;
272                weights
273            }
274            _ => {
275                // Weight inversely proportional to CV error
276                let min_err = cv_errors.iter().copied().fold(f64::INFINITY, f64::min);
277
278                if min_err <= 0.0 || !min_err.is_finite() {
279                    // Fall back to equal weights
280                    return vec![1.0 / n as f64; n];
281                }
282
283                let inv_errors: Vec<f64> = cv_errors
284                    .iter()
285                    .map(|&e| {
286                        if e.is_finite() && e > 0.0 {
287                            1.0 / e
288                        } else {
289                            0.0
290                        }
291                    })
292                    .collect();
293
294                let sum: f64 = inv_errors.iter().sum();
295                if sum > 0.0 {
296                    inv_errors.iter().map(|&w| w / sum).collect()
297                } else {
298                    vec![1.0 / n as f64; n]
299                }
300            }
301        }
302    }
303
304    /// Get the weights of each model in the ensemble
305    pub fn model_weights(&self) -> Vec<(String, f64)> {
306        self.members
307            .iter()
308            .map(|m| (m.name.clone(), m.weight))
309            .collect()
310    }
311
312    /// Get the number of active models in the ensemble
313    pub fn n_active_models(&self) -> usize {
314        self.members
315            .iter()
316            .filter(|m| m.weight >= self.options.min_weight)
317            .count()
318    }
319}
320
321impl SurrogateModel for EnsembleSurrogate {
322    fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> OptimizeResult<()> {
323        let n = x.nrows();
324        if n < 2 {
325            return Err(OptimizeError::InvalidInput(
326                "Need at least 2 data points for ensemble".to_string(),
327            ));
328        }
329
330        self.x_train_raw = Some(x.clone());
331        self.y_train_raw = Some(y.clone());
332
333        // Create fresh members
334        let member_specs = self.create_members();
335        let n_models = member_specs.len();
336
337        if n_models == 0 {
338            return Err(OptimizeError::InvalidInput(
339                "No models enabled for ensemble".to_string(),
340            ));
341        }
342
343        // Fit each model and compute CV error
344        let mut fitted_models: Vec<(Box<dyn SurrogateModel>, String)> = Vec::new();
345        let mut cv_errors: Vec<f64> = Vec::new();
346
347        for (mut model, name) in member_specs {
348            if model.fit(x, y).is_ok() {
349                // Compute CV error based on criterion
350                let cv_err = match self.options.criterion {
351                    ModelSelectionCriterion::Loocv => {
352                        // Approximate LOOCV by computing training error with leave-one-out
353                        if n >= 3 {
354                            let mut total_sq_err = 0.0;
355                            let mut count = 0;
356                            // Use a subset for speed if n is large
357                            let step = if n > 20 { n / 10 } else { 1 };
358                            for i in (0..n).step_by(step) {
359                                let x_i = x.row(i).to_owned();
360                                if let Ok(pred) = model.predict(&x_i) {
361                                    let err = pred - y[i];
362                                    total_sq_err += err * err;
363                                    count += 1;
364                                }
365                            }
366                            // Training error is optimistic; scale up
367                            if count > 0 {
368                                total_sq_err / count as f64 * (n as f64 / (n as f64 - 1.0))
369                            } else {
370                                f64::INFINITY
371                            }
372                        } else {
373                            1.0 // default
374                        }
375                    }
376                    ModelSelectionCriterion::KFold { k } => {
377                        let actual_k = k.min(n).max(2);
378                        let fold_size = n / actual_k;
379                        let mut total_err = 0.0;
380                        let mut count = 0;
381
382                        for fold in 0..actual_k {
383                            let test_start = fold * fold_size;
384                            let test_end = if fold == actual_k - 1 {
385                                n
386                            } else {
387                                (fold + 1) * fold_size
388                            };
389
390                            for i in test_start..test_end {
391                                let x_i = x.row(i).to_owned();
392                                if let Ok(pred) = model.predict(&x_i) {
393                                    let err = pred - y[i];
394                                    total_err += err * err;
395                                    count += 1;
396                                }
397                            }
398                        }
399                        if count > 0 {
400                            total_err / count as f64
401                        } else {
402                            f64::INFINITY
403                        }
404                    }
405                    ModelSelectionCriterion::Aic => {
406                        // AIC approximation: n * ln(MSE) + 2 * k
407                        let mut mse = 0.0;
408                        for i in 0..n {
409                            let x_i = x.row(i).to_owned();
410                            if let Ok(pred) = model.predict(&x_i) {
411                                mse += (pred - y[i]).powi(2);
412                            }
413                        }
414                        mse /= n as f64;
415                        if mse > 0.0 {
416                            n as f64 * mse.ln() + 2.0 * x.ncols() as f64
417                        } else {
418                            f64::NEG_INFINITY
419                        }
420                    }
421                    ModelSelectionCriterion::Equal | ModelSelectionCriterion::BestSingle => 1.0,
422                };
423
424                cv_errors.push(cv_err);
425                fitted_models.push((model, name));
426            }
427        }
428
429        if fitted_models.is_empty() {
430            return Err(OptimizeError::ComputationError(
431                "All ensemble models failed to fit".to_string(),
432            ));
433        }
434
435        // Compute weights
436        let weights = self.compute_weights(&cv_errors);
437
438        // Build ensemble members
439        self.members.clear();
440        for ((model, name), weight) in fitted_models.into_iter().zip(weights.into_iter()) {
441            self.members.push(EnsembleMember {
442                model,
443                name,
444                weight,
445            });
446        }
447
448        Ok(())
449    }
450
451    fn predict(&self, x: &Array1<f64>) -> OptimizeResult<f64> {
452        if self.members.is_empty() {
453            return Err(OptimizeError::ComputationError(
454                "Ensemble not fitted".to_string(),
455            ));
456        }
457
458        let mut prediction = 0.0;
459        let mut weight_sum = 0.0;
460
461        for member in &self.members {
462            if member.weight >= self.options.min_weight {
463                if let Ok(pred) = member.model.predict(x) {
464                    prediction += member.weight * pred;
465                    weight_sum += member.weight;
466                }
467            }
468        }
469
470        if weight_sum > 0.0 {
471            Ok(prediction / weight_sum)
472        } else {
473            Err(OptimizeError::ComputationError(
474                "No ensemble members produced valid predictions".to_string(),
475            ))
476        }
477    }
478
479    fn predict_with_uncertainty(&self, x: &Array1<f64>) -> OptimizeResult<(f64, f64)> {
480        if self.members.is_empty() {
481            return Err(OptimizeError::ComputationError(
482                "Ensemble not fitted".to_string(),
483            ));
484        }
485
486        let mut mean = 0.0;
487        let mut weight_sum = 0.0;
488        let mut predictions = Vec::new();
489        let mut weights_used = Vec::new();
490
491        for member in &self.members {
492            if member.weight >= self.options.min_weight {
493                if let Ok((pred, _unc)) = member.model.predict_with_uncertainty(x) {
494                    mean += member.weight * pred;
495                    weight_sum += member.weight;
496                    predictions.push(pred);
497                    weights_used.push(member.weight);
498                }
499            }
500        }
501
502        if weight_sum <= 0.0 {
503            return Err(OptimizeError::ComputationError(
504                "No ensemble members produced valid predictions".to_string(),
505            ));
506        }
507
508        mean /= weight_sum;
509
510        // Uncertainty: combination of individual uncertainties and model disagreement
511        let mut variance = 0.0;
512        for (pred, w) in predictions.iter().zip(weights_used.iter()) {
513            let diff = pred - mean;
514            variance += (w / weight_sum) * diff * diff;
515        }
516
517        // Add mean uncertainty from individual models
518        let mut mean_unc = 0.0;
519        for member in &self.members {
520            if member.weight >= self.options.min_weight {
521                if let Ok((_pred, unc)) = member.model.predict_with_uncertainty(x) {
522                    mean_unc += member.weight * unc;
523                }
524            }
525        }
526        mean_unc /= weight_sum;
527
528        let total_std = (variance + mean_unc * mean_unc).sqrt().max(1e-10);
529        Ok((mean, total_std))
530    }
531
532    fn n_samples(&self) -> usize {
533        self.x_train_raw.as_ref().map_or(0, |x| x.nrows())
534    }
535
536    fn n_features(&self) -> usize {
537        self.x_train_raw.as_ref().map_or(0, |x| x.ncols())
538    }
539
540    fn update(&mut self, x: &Array1<f64>, y: f64) -> OptimizeResult<()> {
541        // Refit with new data
542        let (new_x, new_y) =
543            if let (Some(ref x_raw), Some(ref y_raw)) = (&self.x_train_raw, &self.y_train_raw) {
544                let n = x_raw.nrows();
545                let d = x_raw.ncols();
546
547                let mut new_x = Array2::zeros((n + 1, d));
548                for i in 0..n {
549                    for j in 0..d {
550                        new_x[[i, j]] = x_raw[[i, j]];
551                    }
552                }
553                for j in 0..d {
554                    new_x[[n, j]] = x[j];
555                }
556
557                let mut new_y = Array1::zeros(n + 1);
558                for i in 0..n {
559                    new_y[i] = y_raw[i];
560                }
561                new_y[n] = y;
562
563                (new_x, new_y)
564            } else {
565                let d = x.len();
566                let mut new_x = Array2::zeros((1, d));
567                for j in 0..d {
568                    new_x[[0, j]] = x[j];
569                }
570                (new_x, Array1::from_vec(vec![y]))
571            };
572
573        self.fit(&new_x, &new_y)
574    }
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// Options with only the RBF members enabled, which keeps fits fast.
    fn rbf_only(criterion: ModelSelectionCriterion) -> EnsembleOptions {
        EnsembleOptions {
            criterion,
            include_kriging_se: false,
            include_kriging_matern52: false,
            ..Default::default()
        }
    }

    /// Six equally spaced points on [0, 1] with targets `x^2`.
    fn quadratic_data() -> (Array2<f64>, Array1<f64>) {
        let x_train = Array2::from_shape_vec((6, 1), vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
            .expect("Array creation failed");
        let y_train = Array1::from_vec(vec![0.0, 0.04, 0.16, 0.36, 0.64, 1.0]);
        (x_train, y_train)
    }

    /// Asserts that the ensemble weights are non-negative and sum to ~1.
    fn assert_weights_normalized(ensemble: &EnsembleSurrogate) {
        let weights = ensemble.model_weights();
        assert!(
            weights.iter().all(|(_, w)| *w >= 0.0),
            "Weights should be non-negative: {:?}",
            weights
        );
        let total_weight: f64 = weights.iter().map(|(_, w)| w).sum();
        assert!(
            (total_weight - 1.0).abs() < 0.01,
            "Weights should sum to ~1.0, got {}",
            total_weight
        );
    }

    #[test]
    fn test_ensemble_basic() {
        let x_train = Array2::from_shape_vec((6, 1), vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
            .expect("Array creation failed");
        let y_train = Array1::from_vec(vec![0.0, 0.4, 1.6, 3.6, 6.4, 10.0]);

        let mut ensemble = EnsembleSurrogate::new(rbf_only(ModelSelectionCriterion::Equal));

        let result = ensemble.fit(&x_train, &y_train);
        assert!(result.is_ok(), "Ensemble fit failed: {:?}", result.err());

        let pred = ensemble.predict(&Array1::from_vec(vec![0.5]));
        assert!(pred.is_ok());
        let val = pred.expect("Ensemble prediction failed");
        // f(x) = 10 x^2, so f(0.5) = 2.5 approximately
        assert!(
            val.abs() < 20.0,
            "Ensemble prediction out of range: {}",
            val
        );
    }

    #[test]
    fn test_ensemble_with_kriging() {
        let x_train = Array2::from_shape_vec((5, 1), vec![0.0, 0.25, 0.5, 0.75, 1.0])
            .expect("Array creation failed");
        let y_train = Array1::from_vec(vec![0.0, 0.0625, 0.25, 0.5625, 1.0]);

        let mut ensemble = EnsembleSurrogate::new(EnsembleOptions {
            criterion: ModelSelectionCriterion::Equal,
            include_rbf_tps: false,
            ..Default::default()
        });

        assert!(ensemble.fit(&x_train, &y_train).is_ok());
        assert!(ensemble.n_active_models() > 0);
    }

    #[test]
    fn test_ensemble_uncertainty() {
        let x_train = Array2::from_shape_vec((4, 1), vec![0.0, 0.33, 0.66, 1.0])
            .expect("Array creation failed");
        let y_train = Array1::from_vec(vec![0.0, 1.0, 1.0, 0.0]);

        let mut ensemble = EnsembleSurrogate::new(rbf_only(ModelSelectionCriterion::Equal));
        ensemble.fit(&x_train, &y_train).expect("Fit failed");

        let result = ensemble.predict_with_uncertainty(&Array1::from_vec(vec![0.5]));
        assert!(result.is_ok());
        let (mean, std) = result.expect("Uncertainty prediction failed");
        assert!(std > 0.0, "Uncertainty should be positive: {}", std);
        assert!(mean.is_finite(), "Mean should be finite: {}", mean);
    }

    #[test]
    fn test_ensemble_best_single() {
        let x_train = Array2::from_shape_vec((5, 1), vec![0.0, 0.25, 0.5, 0.75, 1.0])
            .expect("Array creation failed");
        let y_train = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);

        let mut ensemble =
            EnsembleSurrogate::new(rbf_only(ModelSelectionCriterion::BestSingle));
        ensemble.fit(&x_train, &y_train).expect("Fit failed");

        // Only one model should have weight 1.0
        let weights = ensemble.model_weights();
        let n_nonzero = weights.iter().filter(|(_, w)| *w > 0.0).count();
        assert_eq!(
            n_nonzero, 1,
            "BestSingle should have exactly 1 active model"
        );
    }

    #[test]
    fn test_ensemble_update() {
        let x_train = Array2::from_shape_vec((4, 1), vec![0.0, 0.33, 0.66, 1.0])
            .expect("Array creation failed");
        let y_train = Array1::from_vec(vec![0.0, 1.0, 1.0, 0.0]);

        let mut ensemble = EnsembleSurrogate::new(rbf_only(ModelSelectionCriterion::Equal));
        ensemble.fit(&x_train, &y_train).expect("Fit failed");
        assert_eq!(ensemble.n_samples(), 4);

        ensemble
            .update(&Array1::from_vec(vec![0.5]), 1.0)
            .expect("Update failed");
        assert_eq!(ensemble.n_samples(), 5);
    }

    #[test]
    fn test_ensemble_2d() {
        let x_train = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("Array creation failed");
        let y_train = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);

        let mut ensemble = EnsembleSurrogate::new(rbf_only(ModelSelectionCriterion::Equal));
        assert!(ensemble.fit(&x_train, &y_train).is_ok());

        let pred = ensemble.predict(&Array1::from_vec(vec![0.5, 0.5]));
        assert!(pred.is_ok());
    }

    #[test]
    fn test_ensemble_loocv_criterion() {
        let (x_train, y_train) = quadratic_data();
        let mut ensemble = EnsembleSurrogate::new(rbf_only(ModelSelectionCriterion::Loocv));
        assert!(ensemble.fit(&x_train, &y_train).is_ok());
        assert_weights_normalized(&ensemble);
    }

    #[test]
    fn test_ensemble_kfold_criterion() {
        let (x_train, y_train) = quadratic_data();
        let mut ensemble =
            EnsembleSurrogate::new(rbf_only(ModelSelectionCriterion::KFold { k: 3 }));
        assert!(ensemble.fit(&x_train, &y_train).is_ok());
        assert_weights_normalized(&ensemble);
    }

    #[test]
    fn test_ensemble_aic_criterion() {
        let (x_train, y_train) = quadratic_data();
        let mut ensemble = EnsembleSurrogate::new(rbf_only(ModelSelectionCriterion::Aic));
        assert!(ensemble.fit(&x_train, &y_train).is_ok());
        assert_weights_normalized(&ensemble);
    }

    #[test]
    fn test_ensemble_rejects_too_few_points() {
        let x_train = Array2::from_shape_vec((1, 1), vec![0.5]).expect("Array creation failed");
        let y_train = Array1::from_vec(vec![1.0]);

        let mut ensemble = EnsembleSurrogate::new(EnsembleOptions::default());
        assert!(ensemble.fit(&x_train, &y_train).is_err());
    }

    #[test]
    fn test_ensemble_no_models_enabled() {
        let (x_train, y_train) = quadratic_data();
        let mut ensemble = EnsembleSurrogate::new(EnsembleOptions {
            include_rbf_cubic: false,
            include_rbf_gaussian: false,
            include_rbf_multiquadric: false,
            include_rbf_tps: false,
            include_kriging_se: false,
            include_kriging_matern52: false,
            ..Default::default()
        });
        assert!(ensemble.fit(&x_train, &y_train).is_err());
    }

    #[test]
    fn test_ensemble_predict_before_fit() {
        let ensemble = EnsembleSurrogate::new(EnsembleOptions::default());
        let point = Array1::from_vec(vec![0.5]);
        assert!(ensemble.predict(&point).is_err());
        assert!(ensemble.predict_with_uncertainty(&point).is_err());
        assert_eq!(ensemble.n_samples(), 0);
        assert_eq!(ensemble.n_features(), 0);
        assert_eq!(ensemble.n_active_models(), 0);
    }
}