1use super::{
22 kriging::{CorrelationFunction, KrigingOptions, KrigingSurrogate},
23 rbf_surrogate::{RbfKernel, RbfOptions, RbfSurrogate},
24 SurrogateModel,
25};
26use crate::error::{OptimizeError, OptimizeResult};
27use scirs2_core::ndarray::{Array1, Array2};
28
/// Strategy used to turn per-model error scores into ensemble weights.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ModelSelectionCriterion {
    /// Weight models by the inverse of a leave-one-out style error estimate.
    Loocv,
    /// Weight models by the inverse of an error estimate organised in `k` folds.
    KFold { k: usize },
    /// Score models with an Akaike-information-criterion style value.
    Aic,
    /// Give every fitted model the same weight.
    Equal,
    /// Put the whole weight on the single model with the lowest error score.
    BestSingle,
}
46
47impl Default for ModelSelectionCriterion {
48 fn default() -> Self {
49 ModelSelectionCriterion::Loocv
50 }
51}
52
/// Configuration of an [`EnsembleSurrogate`]: which member models to build
/// and how their weights are derived.
#[derive(Debug, Clone)]
pub struct EnsembleOptions {
    /// Rule used to convert per-model error scores into weights.
    pub criterion: ModelSelectionCriterion,
    /// Include an RBF member with the `Polyharmonic(3)` kernel ("RBF-Cubic").
    pub include_rbf_cubic: bool,
    /// Include an RBF member with a Gaussian kernel ("RBF-Gaussian").
    pub include_rbf_gaussian: bool,
    /// Include an RBF member with a multiquadric kernel ("RBF-MQ").
    pub include_rbf_multiquadric: bool,
    /// Include an RBF member with the thin-plate-spline kernel ("RBF-TPS").
    pub include_rbf_tps: bool,
    /// Include a Kriging member with squared-exponential correlation ("Kriging-SE").
    pub include_kriging_se: bool,
    /// Include a Kriging member with Matern 5/2 correlation ("Kriging-Matern52").
    pub include_kriging_matern52: bool,
    /// Members whose weight is below this threshold are skipped when
    /// predicting and are not counted by `n_active_models`.
    pub min_weight: f64,
    /// Seed forwarded to the Kriging members' options.
    pub seed: Option<u64>,
}
75
76impl Default for EnsembleOptions {
77 fn default() -> Self {
78 Self {
79 criterion: ModelSelectionCriterion::default(),
80 include_rbf_cubic: true,
81 include_rbf_gaussian: true,
82 include_rbf_multiquadric: false,
83 include_rbf_tps: true,
84 include_kriging_se: true,
85 include_kriging_matern52: true,
86 min_weight: 0.01,
87 seed: None,
88 }
89 }
90}
91
/// One fitted model inside the ensemble together with its bookkeeping data.
struct EnsembleMember {
    /// The fitted surrogate, stored behind the common trait object.
    model: Box<dyn SurrogateModel>,
    /// Human-readable label of the member (for example "RBF-Cubic").
    name: String,
    /// Weight assigned to this member when the ensemble was fitted.
    weight: f64,
}
101
/// Weighted combination of several surrogate models (RBF and Kriging
/// variants) exposed through the [`SurrogateModel`] trait.
pub struct EnsembleSurrogate {
    /// Configuration selecting the members and the weighting rule.
    options: EnsembleOptions,
    /// Fitted members with their weights; empty until `fit` has run.
    members: Vec<EnsembleMember>,
    /// Training inputs retained by `fit`; `update` appends a row to them.
    x_train_raw: Option<Array2<f64>>,
    /// Training targets retained by `fit`; `update` appends a value to them.
    y_train_raw: Option<Array1<f64>>,
}
111
112impl EnsembleSurrogate {
113 pub fn new(options: EnsembleOptions) -> Self {
115 Self {
116 options,
117 members: Vec::new(),
118 x_train_raw: None,
119 y_train_raw: None,
120 }
121 }
122
123 fn create_members(&self) -> Vec<(Box<dyn SurrogateModel>, String)> {
125 let mut members: Vec<(Box<dyn SurrogateModel>, String)> = Vec::new();
126
127 if self.options.include_rbf_cubic {
128 members.push((
129 Box::new(RbfSurrogate::new(RbfOptions {
130 kernel: RbfKernel::Polyharmonic(3),
131 regularization: 1e-8,
132 normalize: true,
133 })),
134 "RBF-Cubic".to_string(),
135 ));
136 }
137
138 if self.options.include_rbf_gaussian {
139 members.push((
140 Box::new(RbfSurrogate::new(RbfOptions {
141 kernel: RbfKernel::Gaussian { sigma: 1.0 },
142 regularization: 1e-6,
143 normalize: true,
144 })),
145 "RBF-Gaussian".to_string(),
146 ));
147 }
148
149 if self.options.include_rbf_multiquadric {
150 members.push((
151 Box::new(RbfSurrogate::new(RbfOptions {
152 kernel: RbfKernel::Multiquadric { shape_param: 1.0 },
153 regularization: 1e-8,
154 normalize: true,
155 })),
156 "RBF-MQ".to_string(),
157 ));
158 }
159
160 if self.options.include_rbf_tps {
161 members.push((
162 Box::new(RbfSurrogate::new(RbfOptions {
163 kernel: RbfKernel::ThinPlateSpline,
164 regularization: 1e-8,
165 normalize: true,
166 })),
167 "RBF-TPS".to_string(),
168 ));
169 }
170
171 if self.options.include_kriging_se {
172 members.push((
173 Box::new(KrigingSurrogate::new(KrigingOptions {
174 correlation: CorrelationFunction::SquaredExponential,
175 nugget: Some(1e-4),
176 n_restarts: 3,
177 seed: self.options.seed,
178 ..Default::default()
179 })),
180 "Kriging-SE".to_string(),
181 ));
182 }
183
184 if self.options.include_kriging_matern52 {
185 members.push((
186 Box::new(KrigingSurrogate::new(KrigingOptions {
187 correlation: CorrelationFunction::Matern52,
188 nugget: Some(1e-4),
189 n_restarts: 3,
190 seed: self.options.seed,
191 ..Default::default()
192 })),
193 "Kriging-Matern52".to_string(),
194 ));
195 }
196
197 members
198 }
199
200 fn loocv_error(
202 &self,
203 model_factory: &dyn Fn() -> Box<dyn SurrogateModel>,
204 x: &Array2<f64>,
205 y: &Array1<f64>,
206 ) -> f64 {
207 let n = x.nrows();
208 let d = x.ncols();
209
210 if n < 3 {
211 return f64::INFINITY;
212 }
213
214 let mut total_sq_error = 0.0;
215 let mut valid_count = 0;
216
217 for leave_out in 0..n {
218 let mut x_train = Array2::zeros((n - 1, d));
220 let mut y_train = Array1::zeros(n - 1);
221 let mut idx = 0;
222 for i in 0..n {
223 if i != leave_out {
224 for j in 0..d {
225 x_train[[idx, j]] = x[[i, j]];
226 }
227 y_train[idx] = y[i];
228 idx += 1;
229 }
230 }
231
232 let mut model = model_factory();
233 if model.fit(&x_train, &y_train).is_ok() {
234 let x_test = x.row(leave_out).to_owned();
235 if let Ok(pred) = model.predict(&x_test) {
236 let error = pred - y[leave_out];
237 total_sq_error += error * error;
238 valid_count += 1;
239 }
240 }
241 }
242
243 if valid_count > 0 {
244 total_sq_error / valid_count as f64
245 } else {
246 f64::INFINITY
247 }
248 }
249
250 fn compute_weights(&self, cv_errors: &[f64]) -> Vec<f64> {
252 let n = cv_errors.len();
253 if n == 0 {
254 return Vec::new();
255 }
256
257 match self.options.criterion {
258 ModelSelectionCriterion::Equal => {
259 vec![1.0 / n as f64; n]
260 }
261 ModelSelectionCriterion::BestSingle => {
262 let mut weights = vec![0.0; n];
263 let mut best_idx = 0;
264 let mut best_err = f64::INFINITY;
265 for (i, &err) in cv_errors.iter().enumerate() {
266 if err < best_err {
267 best_err = err;
268 best_idx = i;
269 }
270 }
271 weights[best_idx] = 1.0;
272 weights
273 }
274 _ => {
275 let min_err = cv_errors.iter().copied().fold(f64::INFINITY, f64::min);
277
278 if min_err <= 0.0 || !min_err.is_finite() {
279 return vec![1.0 / n as f64; n];
281 }
282
283 let inv_errors: Vec<f64> = cv_errors
284 .iter()
285 .map(|&e| {
286 if e.is_finite() && e > 0.0 {
287 1.0 / e
288 } else {
289 0.0
290 }
291 })
292 .collect();
293
294 let sum: f64 = inv_errors.iter().sum();
295 if sum > 0.0 {
296 inv_errors.iter().map(|&w| w / sum).collect()
297 } else {
298 vec![1.0 / n as f64; n]
299 }
300 }
301 }
302 }
303
304 pub fn model_weights(&self) -> Vec<(String, f64)> {
306 self.members
307 .iter()
308 .map(|m| (m.name.clone(), m.weight))
309 .collect()
310 }
311
312 pub fn n_active_models(&self) -> usize {
314 self.members
315 .iter()
316 .filter(|m| m.weight >= self.options.min_weight)
317 .count()
318 }
319}
320
321impl SurrogateModel for EnsembleSurrogate {
322 fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> OptimizeResult<()> {
323 let n = x.nrows();
324 if n < 2 {
325 return Err(OptimizeError::InvalidInput(
326 "Need at least 2 data points for ensemble".to_string(),
327 ));
328 }
329
330 self.x_train_raw = Some(x.clone());
331 self.y_train_raw = Some(y.clone());
332
333 let member_specs = self.create_members();
335 let n_models = member_specs.len();
336
337 if n_models == 0 {
338 return Err(OptimizeError::InvalidInput(
339 "No models enabled for ensemble".to_string(),
340 ));
341 }
342
343 let mut fitted_models: Vec<(Box<dyn SurrogateModel>, String)> = Vec::new();
345 let mut cv_errors: Vec<f64> = Vec::new();
346
347 for (mut model, name) in member_specs {
348 if model.fit(x, y).is_ok() {
349 let cv_err = match self.options.criterion {
351 ModelSelectionCriterion::Loocv => {
352 if n >= 3 {
354 let mut total_sq_err = 0.0;
355 let mut count = 0;
356 let step = if n > 20 { n / 10 } else { 1 };
358 for i in (0..n).step_by(step) {
359 let x_i = x.row(i).to_owned();
360 if let Ok(pred) = model.predict(&x_i) {
361 let err = pred - y[i];
362 total_sq_err += err * err;
363 count += 1;
364 }
365 }
366 if count > 0 {
368 total_sq_err / count as f64 * (n as f64 / (n as f64 - 1.0))
369 } else {
370 f64::INFINITY
371 }
372 } else {
373 1.0 }
375 }
376 ModelSelectionCriterion::KFold { k } => {
377 let actual_k = k.min(n).max(2);
378 let fold_size = n / actual_k;
379 let mut total_err = 0.0;
380 let mut count = 0;
381
382 for fold in 0..actual_k {
383 let test_start = fold * fold_size;
384 let test_end = if fold == actual_k - 1 {
385 n
386 } else {
387 (fold + 1) * fold_size
388 };
389
390 for i in test_start..test_end {
391 let x_i = x.row(i).to_owned();
392 if let Ok(pred) = model.predict(&x_i) {
393 let err = pred - y[i];
394 total_err += err * err;
395 count += 1;
396 }
397 }
398 }
399 if count > 0 {
400 total_err / count as f64
401 } else {
402 f64::INFINITY
403 }
404 }
405 ModelSelectionCriterion::Aic => {
406 let mut mse = 0.0;
408 for i in 0..n {
409 let x_i = x.row(i).to_owned();
410 if let Ok(pred) = model.predict(&x_i) {
411 mse += (pred - y[i]).powi(2);
412 }
413 }
414 mse /= n as f64;
415 if mse > 0.0 {
416 n as f64 * mse.ln() + 2.0 * x.ncols() as f64
417 } else {
418 f64::NEG_INFINITY
419 }
420 }
421 ModelSelectionCriterion::Equal | ModelSelectionCriterion::BestSingle => 1.0,
422 };
423
424 cv_errors.push(cv_err);
425 fitted_models.push((model, name));
426 }
427 }
428
429 if fitted_models.is_empty() {
430 return Err(OptimizeError::ComputationError(
431 "All ensemble models failed to fit".to_string(),
432 ));
433 }
434
435 let weights = self.compute_weights(&cv_errors);
437
438 self.members.clear();
440 for ((model, name), weight) in fitted_models.into_iter().zip(weights.into_iter()) {
441 self.members.push(EnsembleMember {
442 model,
443 name,
444 weight,
445 });
446 }
447
448 Ok(())
449 }
450
451 fn predict(&self, x: &Array1<f64>) -> OptimizeResult<f64> {
452 if self.members.is_empty() {
453 return Err(OptimizeError::ComputationError(
454 "Ensemble not fitted".to_string(),
455 ));
456 }
457
458 let mut prediction = 0.0;
459 let mut weight_sum = 0.0;
460
461 for member in &self.members {
462 if member.weight >= self.options.min_weight {
463 if let Ok(pred) = member.model.predict(x) {
464 prediction += member.weight * pred;
465 weight_sum += member.weight;
466 }
467 }
468 }
469
470 if weight_sum > 0.0 {
471 Ok(prediction / weight_sum)
472 } else {
473 Err(OptimizeError::ComputationError(
474 "No ensemble members produced valid predictions".to_string(),
475 ))
476 }
477 }
478
479 fn predict_with_uncertainty(&self, x: &Array1<f64>) -> OptimizeResult<(f64, f64)> {
480 if self.members.is_empty() {
481 return Err(OptimizeError::ComputationError(
482 "Ensemble not fitted".to_string(),
483 ));
484 }
485
486 let mut mean = 0.0;
487 let mut weight_sum = 0.0;
488 let mut predictions = Vec::new();
489 let mut weights_used = Vec::new();
490
491 for member in &self.members {
492 if member.weight >= self.options.min_weight {
493 if let Ok((pred, _unc)) = member.model.predict_with_uncertainty(x) {
494 mean += member.weight * pred;
495 weight_sum += member.weight;
496 predictions.push(pred);
497 weights_used.push(member.weight);
498 }
499 }
500 }
501
502 if weight_sum <= 0.0 {
503 return Err(OptimizeError::ComputationError(
504 "No ensemble members produced valid predictions".to_string(),
505 ));
506 }
507
508 mean /= weight_sum;
509
510 let mut variance = 0.0;
512 for (pred, w) in predictions.iter().zip(weights_used.iter()) {
513 let diff = pred - mean;
514 variance += (w / weight_sum) * diff * diff;
515 }
516
517 let mut mean_unc = 0.0;
519 for member in &self.members {
520 if member.weight >= self.options.min_weight {
521 if let Ok((_pred, unc)) = member.model.predict_with_uncertainty(x) {
522 mean_unc += member.weight * unc;
523 }
524 }
525 }
526 mean_unc /= weight_sum;
527
528 let total_std = (variance + mean_unc * mean_unc).sqrt().max(1e-10);
529 Ok((mean, total_std))
530 }
531
532 fn n_samples(&self) -> usize {
533 self.x_train_raw.as_ref().map_or(0, |x| x.nrows())
534 }
535
536 fn n_features(&self) -> usize {
537 self.x_train_raw.as_ref().map_or(0, |x| x.ncols())
538 }
539
540 fn update(&mut self, x: &Array1<f64>, y: f64) -> OptimizeResult<()> {
541 let (new_x, new_y) =
543 if let (Some(ref x_raw), Some(ref y_raw)) = (&self.x_train_raw, &self.y_train_raw) {
544 let n = x_raw.nrows();
545 let d = x_raw.ncols();
546
547 let mut new_x = Array2::zeros((n + 1, d));
548 for i in 0..n {
549 for j in 0..d {
550 new_x[[i, j]] = x_raw[[i, j]];
551 }
552 }
553 for j in 0..d {
554 new_x[[n, j]] = x[j];
555 }
556
557 let mut new_y = Array1::zeros(n + 1);
558 for i in 0..n {
559 new_y[i] = y_raw[i];
560 }
561 new_y[n] = y;
562
563 (new_x, new_y)
564 } else {
565 let d = x.len();
566 let mut new_x = Array2::zeros((1, d));
567 for j in 0..d {
568 new_x[[0, j]] = x[j];
569 }
570 (new_x, Array1::from_vec(vec![y]))
571 };
572
573 self.fit(&new_x, &new_y)
574 }
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n x 1` input matrix from a flat list of sample locations.
    fn column_matrix(values: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((values.len(), 1), values.to_vec())
            .expect("Array creation failed")
    }

    /// Ensemble with both Kriging members disabled and default RBF members.
    fn rbf_only_ensemble(criterion: ModelSelectionCriterion) -> EnsembleSurrogate {
        EnsembleSurrogate::new(EnsembleOptions {
            criterion,
            include_kriging_se: false,
            include_kriging_matern52: false,
            ..Default::default()
        })
    }

    /// Equal-weight RBF ensemble fits and predicts a value of sane magnitude.
    #[test]
    fn test_ensemble_basic() {
        let x_train = column_matrix(&[0.0, 0.2, 0.4, 0.6, 0.8, 1.0]);
        let y_train = Array1::from_vec(vec![0.0, 0.4, 1.6, 3.6, 6.4, 10.0]);

        let mut ensemble = EnsembleSurrogate::new(EnsembleOptions {
            criterion: ModelSelectionCriterion::Equal,
            include_kriging_se: false,
            include_kriging_matern52: false,
            include_rbf_multiquadric: false,
            ..Default::default()
        });

        let result = ensemble.fit(&x_train, &y_train);
        assert!(result.is_ok(), "Ensemble fit failed: {:?}", result.err());

        let query = Array1::from_vec(vec![0.5]);
        let pred = ensemble.predict(&query);
        assert!(pred.is_ok());
        let val = pred.expect("Ensemble prediction failed");
        assert!(
            val.abs() < 20.0,
            "Ensemble prediction out of range: {}",
            val
        );
    }

    /// Ensemble including the Kriging members fits and keeps an active model.
    #[test]
    fn test_ensemble_with_kriging() {
        let x_train = column_matrix(&[0.0, 0.25, 0.5, 0.75, 1.0]);
        let y_train = Array1::from_vec(vec![0.0, 0.0625, 0.25, 0.5625, 1.0]);

        let options = EnsembleOptions {
            criterion: ModelSelectionCriterion::Equal,
            include_rbf_tps: false,
            ..Default::default()
        };
        let mut ensemble = EnsembleSurrogate::new(options);

        assert!(ensemble.fit(&x_train, &y_train).is_ok());
        assert!(ensemble.n_active_models() > 0);
    }

    /// Uncertainty prediction returns a finite mean and a positive deviation.
    #[test]
    fn test_ensemble_uncertainty() {
        let x_train = column_matrix(&[0.0, 0.33, 0.66, 1.0]);
        let y_train = Array1::from_vec(vec![0.0, 1.0, 1.0, 0.0]);

        let mut ensemble = rbf_only_ensemble(ModelSelectionCriterion::Equal);
        ensemble.fit(&x_train, &y_train).expect("Fit failed");

        let query = Array1::from_vec(vec![0.5]);
        let result = ensemble.predict_with_uncertainty(&query);
        assert!(result.is_ok());
        let (mean, std) = result.expect("Uncertainty prediction failed");
        assert!(std > 0.0, "Uncertainty should be positive: {}", std);
        assert!(mean.is_finite(), "Mean should be finite: {}", mean);
    }

    /// `BestSingle` leaves exactly one member with a non-zero weight.
    #[test]
    fn test_ensemble_best_single() {
        let x_train = column_matrix(&[0.0, 0.25, 0.5, 0.75, 1.0]);
        let y_train = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);

        let mut ensemble = rbf_only_ensemble(ModelSelectionCriterion::BestSingle);
        ensemble.fit(&x_train, &y_train).expect("Fit failed");

        let n_nonzero = ensemble
            .model_weights()
            .iter()
            .filter(|(_, w)| *w > 0.0)
            .count();
        assert_eq!(
            n_nonzero, 1,
            "BestSingle should have exactly 1 active model"
        );
    }

    /// `update` appends one sample to the stored training set.
    #[test]
    fn test_ensemble_update() {
        let x_train = column_matrix(&[0.0, 0.33, 0.66, 1.0]);
        let y_train = Array1::from_vec(vec![0.0, 1.0, 1.0, 0.0]);

        let mut ensemble = rbf_only_ensemble(ModelSelectionCriterion::Equal);
        ensemble.fit(&x_train, &y_train).expect("Fit failed");
        assert_eq!(ensemble.n_samples(), 4);

        let new_point = Array1::from_vec(vec![0.5]);
        ensemble.update(&new_point, 1.0).expect("Update failed");
        assert_eq!(ensemble.n_samples(), 5);
    }

    /// Fitting and predicting works with two input features.
    #[test]
    fn test_ensemble_2d() {
        let corners = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let x_train = Array2::from_shape_vec((4, 2), corners).expect("Array creation failed");
        let y_train = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);

        let mut ensemble = rbf_only_ensemble(ModelSelectionCriterion::Equal);
        assert!(ensemble.fit(&x_train, &y_train).is_ok());

        let pred = ensemble.predict(&Array1::from_vec(vec![0.5, 0.5]));
        assert!(pred.is_ok());
    }

    /// Weights produced under the `Loocv` criterion sum to one.
    #[test]
    fn test_ensemble_loocv_criterion() {
        let x_train = column_matrix(&[0.0, 0.2, 0.4, 0.6, 0.8, 1.0]);
        let y_train = Array1::from_vec(vec![0.0, 0.04, 0.16, 0.36, 0.64, 1.0]);

        let mut ensemble = rbf_only_ensemble(ModelSelectionCriterion::Loocv);
        assert!(ensemble.fit(&x_train, &y_train).is_ok());

        let total_weight: f64 = ensemble.model_weights().iter().map(|(_, w)| w).sum();
        assert!(
            (total_weight - 1.0).abs() < 0.01,
            "Weights should sum to ~1.0, got {}",
            total_weight
        );
    }

    /// Fitting succeeds under the `KFold` criterion.
    #[test]
    fn test_ensemble_kfold_criterion() {
        let x_train = column_matrix(&[0.0, 0.2, 0.4, 0.6, 0.8, 1.0]);
        let y_train = Array1::from_vec(vec![0.0, 0.04, 0.16, 0.36, 0.64, 1.0]);

        let mut ensemble = rbf_only_ensemble(ModelSelectionCriterion::KFold { k: 3 });
        assert!(ensemble.fit(&x_train, &y_train).is_ok());
    }
}