Skip to main content

scirs2_optimize/bayesian/
optimizer.rs

1//! Bayesian Optimizer -- the main driver for Bayesian optimization.
2//!
3//! Orchestrates the GP surrogate, acquisition function, and sampling strategy
4//! into a full sequential/batch optimization loop.
5//!
6//! # Features
7//!
8//! - Configurable surrogate (GP with any kernel)
9//! - Pluggable acquisition functions (EI, PI, UCB, KG, Thompson, batch variants)
10//! - Initial design via Latin Hypercube, Sobol, Halton, or random sampling
11//! - Sequential and batch optimization loops
12//! - Multi-objective Bayesian optimization via ParEGO scalarization
13//! - Constraint handling via feasibility-aware incumbent selection (feasible points are preferred)
14//! - Warm-starting from previous evaluations
15
16use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
17use scirs2_core::random::rngs::StdRng;
18use scirs2_core::random::{Rng, RngExt, SeedableRng};
19
20use crate::error::{OptimizeError, OptimizeResult};
21
22use super::acquisition::{AcquisitionFn, AcquisitionType, ExpectedImprovement};
23use super::gp::{GpSurrogate, GpSurrogateConfig, RbfKernel, SurrogateKernel};
24use super::sampling::{generate_samples, SamplingConfig, SamplingStrategy};
25
26// ---------------------------------------------------------------------------
27// Configuration
28// ---------------------------------------------------------------------------
29
30/// Configuration for the Bayesian optimizer.
31#[derive(Clone)]
32pub struct BayesianOptimizerConfig {
33    /// Acquisition function type.
34    pub acquisition: AcquisitionType,
35    /// Sampling strategy for initial design.
36    pub initial_design: SamplingStrategy,
37    /// Number of initial random/quasi-random points.
38    pub n_initial: usize,
39    /// Number of restarts when optimising the acquisition function.
40    pub acq_n_restarts: usize,
41    /// Number of random candidates evaluated per restart when optimising acquisition.
42    pub acq_n_candidates: usize,
43    /// GP surrogate configuration.
44    pub gp_config: GpSurrogateConfig,
45    /// Random seed for reproducibility.
46    pub seed: Option<u64>,
47    /// Verbosity level (0 = silent, 1 = summary, 2 = per-iteration).
48    pub verbose: usize,
49}
50
51impl Default for BayesianOptimizerConfig {
52    fn default() -> Self {
53        Self {
54            acquisition: AcquisitionType::EI { xi: 0.01 },
55            initial_design: SamplingStrategy::LatinHypercube,
56            n_initial: 10,
57            acq_n_restarts: 5,
58            acq_n_candidates: 200,
59            gp_config: GpSurrogateConfig::default(),
60            seed: None,
61            verbose: 0,
62        }
63    }
64}
65
66// ---------------------------------------------------------------------------
67// Observation record
68// ---------------------------------------------------------------------------
69
70/// A single evaluated observation.
71#[derive(Debug, Clone)]
72pub struct Observation {
73    /// Input point.
74    pub x: Array1<f64>,
75    /// Objective function value.
76    pub y: f64,
77    /// Constraint violation values (empty if no constraints).
78    pub constraints: Vec<f64>,
79    /// Whether this point is feasible (all constraints satisfied).
80    pub feasible: bool,
81}
82
83// ---------------------------------------------------------------------------
84// Optimization result
85// ---------------------------------------------------------------------------
86
87/// Result of Bayesian optimization.
88#[derive(Debug, Clone)]
89pub struct BayesianOptResult {
90    /// Best input point found.
91    pub x_best: Array1<f64>,
92    /// Best objective function value found.
93    pub f_best: f64,
94    /// All observations in order.
95    pub observations: Vec<Observation>,
96    /// Number of function evaluations.
97    pub n_evals: usize,
98    /// History of best values found at each iteration.
99    pub best_history: Vec<f64>,
100    /// Whether the optimisation was successful.
101    pub success: bool,
102    /// Message about the optimization.
103    pub message: String,
104}
105
106// ---------------------------------------------------------------------------
107// Constraint specification
108// ---------------------------------------------------------------------------
109
110/// A constraint for constrained Bayesian optimization.
111///
112/// The constraint is satisfied when `g(x) <= 0`.
113pub struct Constraint {
114    /// Constraint function: returns a scalar value; satisfied when <= 0.
115    pub func: Box<dyn Fn(&ArrayView1<f64>) -> f64 + Send + Sync>,
116    /// Name for diagnostic purposes.
117    pub name: String,
118}
119
120// ---------------------------------------------------------------------------
121// BayesianOptimizer
122// ---------------------------------------------------------------------------
123
124/// The Bayesian optimizer.
125///
126/// Supports sequential single-objective, batch, multi-objective (ParEGO),
127/// and constrained optimization.
128pub struct BayesianOptimizer {
129    /// Search bounds: [(lower, upper), ...] for each dimension.
130    bounds: Vec<(f64, f64)>,
131    /// Configuration.
132    config: BayesianOptimizerConfig,
133    /// GP surrogate model.
134    surrogate: GpSurrogate,
135    /// Observations collected so far.
136    observations: Vec<Observation>,
137    /// Current best observation index.
138    best_idx: Option<usize>,
139    /// Constraints (empty for unconstrained).
140    constraints: Vec<Constraint>,
141    /// Random number generator.
142    rng: StdRng,
143}
144
145impl BayesianOptimizer {
146    /// Create a new Bayesian optimizer.
147    ///
148    /// # Arguments
149    /// * `bounds` - Search bounds for each dimension: `[(lo, hi), ...]`
150    /// * `config` - Optimizer configuration
151    pub fn new(bounds: Vec<(f64, f64)>, config: BayesianOptimizerConfig) -> OptimizeResult<Self> {
152        if bounds.is_empty() {
153            return Err(OptimizeError::InvalidInput(
154                "Bounds must have at least one dimension".to_string(),
155            ));
156        }
157        for (i, &(lo, hi)) in bounds.iter().enumerate() {
158            if lo >= hi {
159                return Err(OptimizeError::InvalidInput(format!(
160                    "Invalid bounds for dimension {}: [{}, {}]",
161                    i, lo, hi
162                )));
163            }
164        }
165
166        let seed = config.seed.unwrap_or_else(|| {
167            let s: u64 = scirs2_core::random::rng().random();
168            s
169        });
170        let rng = StdRng::seed_from_u64(seed);
171
172        let kernel: Box<dyn SurrogateKernel> = Box::new(RbfKernel::default());
173        let surrogate = GpSurrogate::new(kernel, config.gp_config.clone());
174
175        Ok(Self {
176            bounds,
177            config,
178            surrogate,
179            observations: Vec::new(),
180            best_idx: None,
181            constraints: Vec::new(),
182            rng,
183        })
184    }
185
186    /// Create a new optimizer with a custom kernel.
187    pub fn with_kernel(
188        bounds: Vec<(f64, f64)>,
189        kernel: Box<dyn SurrogateKernel>,
190        config: BayesianOptimizerConfig,
191    ) -> OptimizeResult<Self> {
192        let mut opt = Self::new(bounds, config)?;
193        opt.surrogate = GpSurrogate::new(kernel, opt.config.gp_config.clone());
194        Ok(opt)
195    }
196
197    /// Add a constraint: satisfied when `g(x) <= 0`.
198    pub fn add_constraint<F>(&mut self, name: &str, func: F)
199    where
200        F: Fn(&ArrayView1<f64>) -> f64 + Send + Sync + 'static,
201    {
202        self.constraints.push(Constraint {
203            func: Box::new(func),
204            name: name.to_string(),
205        });
206    }
207
208    /// Warm-start from previous evaluations.
209    pub fn warm_start(&mut self, x_data: &Array2<f64>, y_data: &Array1<f64>) -> OptimizeResult<()> {
210        if x_data.nrows() != y_data.len() {
211            return Err(OptimizeError::InvalidInput(
212                "x_data and y_data row counts must match".to_string(),
213            ));
214        }
215
216        for i in 0..x_data.nrows() {
217            let obs = Observation {
218                x: x_data.row(i).to_owned(),
219                y: y_data[i],
220                constraints: Vec::new(),
221                feasible: true,
222            };
223
224            // Track best
225            match self.best_idx {
226                Some(best) if obs.y < self.observations[best].y => {
227                    self.best_idx = Some(self.observations.len());
228                }
229                None => {
230                    self.best_idx = Some(self.observations.len());
231                }
232                _ => {}
233            }
234            self.observations.push(obs);
235        }
236
237        // Fit the surrogate
238        if !self.observations.is_empty() {
239            self.fit_surrogate()?;
240        }
241
242        Ok(())
243    }
244
245    /// Run the sequential optimization loop.
246    ///
247    /// # Arguments
248    /// * `objective` - Function to minimize.
249    /// * `n_iter` - Number of iterations (function evaluations after initial design).
250    pub fn optimize<F>(&mut self, objective: F, n_iter: usize) -> OptimizeResult<BayesianOptResult>
251    where
252        F: Fn(&ArrayView1<f64>) -> f64,
253    {
254        // Phase 1: Initial design
255        let n_initial = if self.observations.is_empty() {
256            self.config.n_initial
257        } else {
258            // If warm-started, may need fewer initial points
259            self.config
260                .n_initial
261                .saturating_sub(self.observations.len())
262        };
263
264        if n_initial > 0 {
265            let sampling_config = SamplingConfig {
266                seed: Some(self.rng.random()),
267                ..Default::default()
268            };
269            let initial_points = generate_samples(
270                n_initial,
271                &self.bounds,
272                self.config.initial_design,
273                Some(sampling_config),
274            )?;
275
276            for i in 0..initial_points.nrows() {
277                let x = initial_points.row(i).to_owned();
278                let y = objective(&x.view());
279                self.record_observation(x, y);
280            }
281
282            self.fit_surrogate()?;
283        }
284
285        let mut best_history = Vec::with_capacity(n_iter);
286        if let Some(best_idx) = self.best_idx {
287            best_history.push(self.observations[best_idx].y);
288        }
289
290        // Phase 2: Sequential optimization
291        for _iter in 0..n_iter {
292            let next_x = self.suggest_next()?;
293            let y = objective(&next_x.view());
294            self.record_observation(next_x, y);
295            self.fit_surrogate()?;
296
297            if let Some(best_idx) = self.best_idx {
298                best_history.push(self.observations[best_idx].y);
299            }
300        }
301
302        // Build result
303        let best_idx = self.best_idx.ok_or_else(|| {
304            OptimizeError::ComputationError("No observations collected".to_string())
305        })?;
306        let best_obs = &self.observations[best_idx];
307
308        Ok(BayesianOptResult {
309            x_best: best_obs.x.clone(),
310            f_best: best_obs.y,
311            observations: self.observations.clone(),
312            n_evals: self.observations.len(),
313            best_history,
314            success: true,
315            message: format!(
316                "Optimization completed: {} evaluations, best f = {:.6e}",
317                self.observations.len(),
318                best_obs.y
319            ),
320        })
321    }
322
323    /// Run batch optimization, evaluating `batch_size` points in parallel per round.
324    ///
325    /// Uses the Kriging Believer strategy: after selecting a candidate,
326    /// the GP is updated with a fantasised observation at the predicted mean.
327    pub fn optimize_batch<F>(
328        &mut self,
329        objective: F,
330        n_rounds: usize,
331        batch_size: usize,
332    ) -> OptimizeResult<BayesianOptResult>
333    where
334        F: Fn(&ArrayView1<f64>) -> f64,
335    {
336        let batch_size = batch_size.max(1);
337
338        // Phase 1: Initial design (same as sequential)
339        let n_initial = if self.observations.is_empty() {
340            self.config.n_initial
341        } else {
342            self.config
343                .n_initial
344                .saturating_sub(self.observations.len())
345        };
346
347        if n_initial > 0 {
348            let sampling_config = SamplingConfig {
349                seed: Some(self.rng.random()),
350                ..Default::default()
351            };
352            let initial_points = generate_samples(
353                n_initial,
354                &self.bounds,
355                self.config.initial_design,
356                Some(sampling_config),
357            )?;
358
359            for i in 0..initial_points.nrows() {
360                let x = initial_points.row(i).to_owned();
361                let y = objective(&x.view());
362                self.record_observation(x, y);
363            }
364            self.fit_surrogate()?;
365        }
366
367        let mut best_history = Vec::with_capacity(n_rounds);
368        if let Some(best_idx) = self.best_idx {
369            best_history.push(self.observations[best_idx].y);
370        }
371
372        // Phase 2: Batch optimization rounds
373        for _round in 0..n_rounds {
374            let batch = self.suggest_batch(batch_size)?;
375
376            // Evaluate all batch points
377            for x in &batch {
378                let y = objective(&x.view());
379                self.record_observation(x.clone(), y);
380            }
381
382            self.fit_surrogate()?;
383
384            if let Some(best_idx) = self.best_idx {
385                best_history.push(self.observations[best_idx].y);
386            }
387        }
388
389        let best_idx = self.best_idx.ok_or_else(|| {
390            OptimizeError::ComputationError("No observations collected".to_string())
391        })?;
392        let best_obs = &self.observations[best_idx];
393
394        Ok(BayesianOptResult {
395            x_best: best_obs.x.clone(),
396            f_best: best_obs.y,
397            observations: self.observations.clone(),
398            n_evals: self.observations.len(),
399            best_history,
400            success: true,
401            message: format!(
402                "Batch optimization completed: {} evaluations, best f = {:.6e}",
403                self.observations.len(),
404                best_obs.y
405            ),
406        })
407    }
408
409    /// Multi-objective optimization via ParEGO scalarization.
410    ///
411    /// Uses random weight vectors to scalarise the objectives into a single
412    /// augmented Chebyshev function, then runs standard BO on the scalarization.
413    ///
414    /// # Arguments
415    /// * `objectives` - Vector of objective functions to minimize.
416    /// * `n_iter` - Number of sequential iterations.
417    pub fn optimize_multi_objective<F>(
418        &mut self,
419        objectives: &[F],
420        n_iter: usize,
421    ) -> OptimizeResult<BayesianOptResult>
422    where
423        F: Fn(&ArrayView1<f64>) -> f64,
424    {
425        if objectives.is_empty() {
426            return Err(OptimizeError::InvalidInput(
427                "At least one objective is required".to_string(),
428            ));
429        }
430        if objectives.len() == 1 {
431            // Single objective: delegate to standard optimize
432            return self.optimize(&objectives[0], n_iter);
433        }
434
435        let n_obj = objectives.len();
436
437        // Phase 1: Initial design
438        let n_initial = if self.observations.is_empty() {
439            self.config.n_initial
440        } else {
441            self.config
442                .n_initial
443                .saturating_sub(self.observations.len())
444        };
445
446        // Store all objective values for normalization
447        let mut all_obj_values: Vec<Vec<f64>> = vec![Vec::new(); n_obj];
448
449        if n_initial > 0 {
450            let sampling_config = SamplingConfig {
451                seed: Some(self.rng.random()),
452                ..Default::default()
453            };
454            let initial_points = generate_samples(
455                n_initial,
456                &self.bounds,
457                self.config.initial_design,
458                Some(sampling_config),
459            )?;
460
461            for i in 0..initial_points.nrows() {
462                let x = initial_points.row(i).to_owned();
463                let obj_vals: Vec<f64> = objectives.iter().map(|f| f(&x.view())).collect();
464
465                // ParEGO scalarization with uniform weight (initial)
466                let scalarized = parego_scalarize(&obj_vals, &vec![1.0 / n_obj as f64; n_obj]);
467                self.record_observation(x, scalarized);
468
469                for (k, &v) in obj_vals.iter().enumerate() {
470                    all_obj_values[k].push(v);
471                }
472            }
473            self.fit_surrogate()?;
474        }
475
476        let mut best_history = Vec::new();
477        if let Some(best_idx) = self.best_idx {
478            best_history.push(self.observations[best_idx].y);
479        }
480
481        // Phase 2: Sequential iterations with rotating random weights
482        for _iter in 0..n_iter {
483            // Generate random weight vector on the simplex
484            let weights = random_simplex_point(n_obj, &mut self.rng);
485
486            // Suggest next point (based on current scalarized GP)
487            let next_x = self.suggest_next()?;
488
489            // Evaluate all objectives
490            let obj_vals: Vec<f64> = objectives.iter().map(|f| f(&next_x.view())).collect();
491            for (k, &v) in obj_vals.iter().enumerate() {
492                all_obj_values[k].push(v);
493            }
494
495            // Normalize and scalarize
496            let normalized: Vec<f64> = (0..n_obj)
497                .map(|k| {
498                    let vals = &all_obj_values[k];
499                    let min_v = vals.iter().copied().fold(f64::INFINITY, f64::min);
500                    let max_v = vals.iter().copied().fold(f64::NEG_INFINITY, f64::max);
501                    let range = (max_v - min_v).max(1e-12);
502                    (obj_vals[k] - min_v) / range
503                })
504                .collect();
505
506            let scalarized = parego_scalarize(&normalized, &weights);
507            self.record_observation(next_x, scalarized);
508            self.fit_surrogate()?;
509
510            if let Some(best_idx) = self.best_idx {
511                best_history.push(self.observations[best_idx].y);
512            }
513        }
514
515        let best_idx = self.best_idx.ok_or_else(|| {
516            OptimizeError::ComputationError("No observations collected".to_string())
517        })?;
518        let best_obs = &self.observations[best_idx];
519
520        Ok(BayesianOptResult {
521            x_best: best_obs.x.clone(),
522            f_best: best_obs.y,
523            observations: self.observations.clone(),
524            n_evals: self.observations.len(),
525            best_history,
526            success: true,
527            message: format!(
528                "ParEGO multi-objective optimization completed: {} evaluations",
529                self.observations.len()
530            ),
531        })
532    }
533
534    /// Get the ask interface: suggest the next point to evaluate.
535    pub fn ask(&mut self) -> OptimizeResult<Array1<f64>> {
536        if self.observations.is_empty() || self.observations.len() < self.config.n_initial {
537            // Still in initial design phase
538            let sampling_config = SamplingConfig {
539                seed: Some(self.rng.random()),
540                ..Default::default()
541            };
542            let points = generate_samples(
543                1,
544                &self.bounds,
545                self.config.initial_design,
546                Some(sampling_config),
547            )?;
548            Ok(points.row(0).to_owned())
549        } else {
550            self.suggest_next()
551        }
552    }
553
554    /// Tell interface: update with an observation.
555    pub fn tell(&mut self, x: Array1<f64>, y: f64) -> OptimizeResult<()> {
556        self.record_observation(x, y);
557        if self.observations.len() >= 2 {
558            self.fit_surrogate()?;
559        }
560        Ok(())
561    }
562
563    /// Get the current best observation.
564    pub fn best(&self) -> Option<&Observation> {
565        self.best_idx.map(|i| &self.observations[i])
566    }
567
568    /// Get all observations.
569    pub fn observations(&self) -> &[Observation] {
570        &self.observations
571    }
572
573    /// Number of observations.
574    pub fn n_observations(&self) -> usize {
575        self.observations.len()
576    }
577
578    /// Get reference to the GP surrogate.
579    pub fn surrogate(&self) -> &GpSurrogate {
580        &self.surrogate
581    }
582
583    // -----------------------------------------------------------------------
584    // Internal methods
585    // -----------------------------------------------------------------------
586
587    /// Record an observation and update the best index.
588    fn record_observation(&mut self, x: Array1<f64>, y: f64) {
589        let feasible = self.evaluate_constraints(&x);
590
591        let obs = Observation {
592            x,
593            y,
594            constraints: Vec::new(), // filled below if needed
595            feasible,
596        };
597
598        let idx = self.observations.len();
599
600        // Update best (prefer feasible solutions)
601        match self.best_idx {
602            Some(best) => {
603                let cur_best = &self.observations[best];
604                let new_is_better = if obs.feasible && !cur_best.feasible {
605                    true
606                } else if obs.feasible == cur_best.feasible {
607                    obs.y < cur_best.y
608                } else {
609                    false
610                };
611                if new_is_better {
612                    self.best_idx = Some(idx);
613                }
614            }
615            None => {
616                self.best_idx = Some(idx);
617            }
618        }
619
620        self.observations.push(obs);
621    }
622
623    /// Evaluate constraints for a point; returns true if all constraints are satisfied.
624    fn evaluate_constraints(&self, x: &Array1<f64>) -> bool {
625        self.constraints.iter().all(|c| (c.func)(&x.view()) <= 0.0)
626    }
627
628    /// Fit or refit the GP surrogate on all observations.
629    fn fit_surrogate(&mut self) -> OptimizeResult<()> {
630        let n = self.observations.len();
631        if n == 0 {
632            return Ok(());
633        }
634        let n_dims = self.observations[0].x.len();
635
636        let mut x_data = Array2::zeros((n, n_dims));
637        let mut y_data = Array1::zeros(n);
638
639        for (i, obs) in self.observations.iter().enumerate() {
640            for j in 0..n_dims {
641                x_data[[i, j]] = obs.x[j];
642            }
643            y_data[i] = obs.y;
644        }
645
646        self.surrogate.fit(&x_data, &y_data)
647    }
648
649    /// Suggest the next point to evaluate by optimising the acquisition function.
650    fn suggest_next(&mut self) -> OptimizeResult<Array1<f64>> {
651        let f_best = self.best_idx.map(|i| self.observations[i].y).unwrap_or(0.0);
652
653        // Build reference points for KG if needed
654        let n = self.observations.len();
655        let n_dims = self.bounds.len();
656        let ref_points = if n > 0 {
657            let mut pts = Array2::zeros((n, n_dims));
658            for (i, obs) in self.observations.iter().enumerate() {
659                for j in 0..n_dims {
660                    pts[[i, j]] = obs.x[j];
661                }
662            }
663            Some(pts)
664        } else {
665            None
666        };
667
668        let acq = self.config.acquisition.build(f_best, ref_points.as_ref());
669
670        self.optimize_acquisition(acq.as_ref())
671    }
672
673    /// Suggest a batch of points using the Kriging Believer strategy.
674    fn suggest_batch(&mut self, batch_size: usize) -> OptimizeResult<Vec<Array1<f64>>> {
675        let mut batch = Vec::with_capacity(batch_size);
676
677        for _ in 0..batch_size {
678            let next = self.suggest_next()?;
679
680            // Fantasy: predict mean at the selected point and add it as a phantom observation
681            let (mu, _sigma) = self.surrogate.predict_single(&next.view())?;
682            self.record_observation(next.clone(), mu);
683            self.fit_surrogate()?;
684
685            batch.push(next);
686        }
687
688        // Remove the phantom observations (they will be replaced with real ones)
689        let n_real = self.observations.len() - batch_size;
690        self.observations.truncate(n_real);
691
692        // Refit surrogate without phantoms
693        if !self.observations.is_empty() {
694            // Update best_idx in case we removed the best
695            self.best_idx = None;
696            for (i, obs) in self.observations.iter().enumerate() {
697                match self.best_idx {
698                    Some(best) if obs.y < self.observations[best].y => {
699                        self.best_idx = Some(i);
700                    }
701                    None => {
702                        self.best_idx = Some(i);
703                    }
704                    _ => {}
705                }
706            }
707            self.fit_surrogate()?;
708        }
709
710        Ok(batch)
711    }
712
713    /// Optimise the acquisition function over the search space.
714    ///
715    /// Uses random sampling + local refinement (coordinate search).
716    fn optimize_acquisition(&mut self, acq: &dyn AcquisitionFn) -> OptimizeResult<Array1<f64>> {
717        let n_dims = self.bounds.len();
718        let n_candidates = self.config.acq_n_candidates;
719        let n_restarts = self.config.acq_n_restarts;
720
721        // Generate random candidates
722        let sampling_config = SamplingConfig {
723            seed: Some(self.rng.random()),
724            ..Default::default()
725        };
726        let candidates = generate_samples(
727            n_candidates,
728            &self.bounds,
729            SamplingStrategy::Random,
730            Some(sampling_config),
731        )?;
732
733        // Also include the current best as a candidate
734        let mut best_x = candidates.row(0).to_owned();
735        let mut best_val = f64::NEG_INFINITY;
736
737        // Evaluate all candidates
738        for i in 0..candidates.nrows() {
739            match acq.evaluate(&candidates.row(i), &self.surrogate) {
740                Ok(val) if val > best_val => {
741                    best_val = val;
742                    best_x = candidates.row(i).to_owned();
743                }
744                _ => {}
745            }
746        }
747
748        // If we have a current best observation, add it as a candidate
749        if let Some(best_idx) = self.best_idx {
750            let obs_x = &self.observations[best_idx].x;
751            if let Ok(val) = acq.evaluate(&obs_x.view(), &self.surrogate) {
752                if val > best_val {
753                    best_val = val;
754                    best_x = obs_x.clone();
755                }
756            }
757        }
758
759        // Local refinement: coordinate-wise search from the top-n candidates
760        // Collect top candidates
761        let mut scored: Vec<(f64, usize)> = Vec::new();
762        for i in 0..candidates.nrows() {
763            if let Ok(val) = acq.evaluate(&candidates.row(i), &self.surrogate) {
764                scored.push((val, i));
765            }
766        }
767        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
768
769        let n_refine = n_restarts.min(scored.len());
770        for k in 0..n_refine {
771            let mut x_current = candidates.row(scored[k].1).to_owned();
772            let mut f_current = scored[k].0;
773
774            // Coordinate-wise golden section search
775            for _round in 0..3 {
776                for d in 0..n_dims {
777                    let (lo, hi) = self.bounds[d];
778                    let (refined_x, refined_f) =
779                        golden_section_1d(acq, &self.surrogate, &x_current, d, lo, hi, 20)?;
780                    if refined_f > f_current {
781                        x_current[d] = refined_x;
782                        f_current = refined_f;
783                    }
784                }
785            }
786
787            if f_current > best_val {
788                best_val = f_current;
789                best_x = x_current;
790            }
791        }
792
793        // Clamp to bounds
794        for (d, &(lo, hi)) in self.bounds.iter().enumerate() {
795            best_x[d] = best_x[d].clamp(lo, hi);
796        }
797
798        Ok(best_x)
799    }
800}
801
802// ---------------------------------------------------------------------------
803// Helper functions
804// ---------------------------------------------------------------------------
805
806/// Golden section search for maximising `acq(x_base with dim d = t)` over [lo, hi].
807fn golden_section_1d(
808    acq: &dyn AcquisitionFn,
809    surrogate: &GpSurrogate,
810    x_base: &Array1<f64>,
811    dim: usize,
812    lo: f64,
813    hi: f64,
814    max_iters: usize,
815) -> OptimizeResult<(f64, f64)> {
816    let gr = (5.0_f64.sqrt() - 1.0) / 2.0; // golden ratio conjugate
817    let mut a = lo;
818    let mut b = hi;
819
820    let eval_at = |t: f64| -> OptimizeResult<f64> {
821        let mut x = x_base.clone();
822        x[dim] = t;
823        acq.evaluate(&x.view(), surrogate)
824    };
825
826    let mut c = b - gr * (b - a);
827    let mut d = a + gr * (b - a);
828    let mut fc = eval_at(c)?;
829    let mut fd = eval_at(d)?;
830
831    for _ in 0..max_iters {
832        if (b - a).abs() < 1e-8 {
833            break;
834        }
835        // We want to maximise, so we keep the side with the larger value
836        if fc < fd {
837            a = c;
838            c = d;
839            fc = fd;
840            d = a + gr * (b - a);
841            fd = eval_at(d)?;
842        } else {
843            b = d;
844            d = c;
845            fd = fc;
846            c = b - gr * (b - a);
847            fc = eval_at(c)?;
848        }
849    }
850
851    let mid = (a + b) / 2.0;
852    let f_mid = eval_at(mid)?;
853    Ok((mid, f_mid))
854}
855
/// ParEGO augmented Chebyshev scalarization.
///
/// `s(f, w) = max_k { w_k * f_k } + rho * sum_k { w_k * f_k }` with `rho = 0.05`,
/// a small augmentation coefficient.
///
/// Values and weights are paired element-wise; any surplus entries in the
/// longer slice are ignored. For empty input there is no maximum term and the
/// result is `f64::NEG_INFINITY`. NaN products never become the maximum.
fn parego_scalarize(obj_values: &[f64], weights: &[f64]) -> f64 {
    const RHO: f64 = 0.05;

    let (max_wf, sum_wf) = obj_values
        .iter()
        .zip(weights)
        .map(|(&f, &w)| w * f)
        .fold((f64::NEG_INFINITY, 0.0), |(max_acc, sum_acc), wf| {
            let max_acc = if wf > max_acc { wf } else { max_acc };
            (max_acc, sum_acc + wf)
        });

    max_wf + RHO * sum_wf
}
876
877/// Generate a random point on the probability simplex using the Dirichlet trick.
878fn random_simplex_point(n: usize, rng: &mut StdRng) -> Vec<f64> {
879    if n == 0 {
880        return Vec::new();
881    }
882    if n == 1 {
883        return vec![1.0];
884    }
885
886    // Sample from Exp(1) and normalize
887    let mut values: Vec<f64> = (0..n)
888        .map(|_| {
889            let u: f64 = rng.random_range(1e-10..1.0);
890            -u.ln()
891        })
892        .collect();
893
894    let sum: f64 = values.iter().sum();
895    if sum > 0.0 {
896        for v in &mut values {
897            *v /= sum;
898        }
899    } else {
900        // Fallback to uniform
901        let w = 1.0 / n as f64;
902        values.fill(w);
903    }
904    values
905}
906
907// ---------------------------------------------------------------------------
908// Convenience function
909// ---------------------------------------------------------------------------
910
911/// Run Bayesian optimization on a function.
912///
913/// This is a high-level convenience function that creates a `BayesianOptimizer`,
914/// runs the optimization, and returns the result.
915///
916/// # Arguments
917/// * `objective` - Function to minimize: `f(x) -> f64`
918/// * `bounds` - Search bounds: `[(lo, hi), ...]`
919/// * `n_iter` - Number of sequential iterations (after initial design)
920/// * `config` - Optional optimizer configuration
921///
922/// # Example
923///
924/// ```rust
925/// use scirs2_optimize::bayesian::optimize;
926/// use scirs2_core::ndarray::ArrayView1;
927///
928/// let result = optimize(
929///     |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2),
930///     &[(-5.0, 5.0), (-5.0, 5.0)],
931///     20,
932///     None,
933/// ).expect("optimization failed");
934///
935/// assert!(result.f_best < 1.0);
936/// ```
937pub fn optimize<F>(
938    objective: F,
939    bounds: &[(f64, f64)],
940    n_iter: usize,
941    config: Option<BayesianOptimizerConfig>,
942) -> OptimizeResult<BayesianOptResult>
943where
944    F: Fn(&ArrayView1<f64>) -> f64,
945{
946    let config = config.unwrap_or_default();
947    let mut optimizer = BayesianOptimizer::new(bounds.to_vec(), config)?;
948    optimizer.optimize(objective, n_iter)
949}
950
951// ---------------------------------------------------------------------------
952// Tests
953// ---------------------------------------------------------------------------
954
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Sum of squares; global minimum 0 at the origin.
    fn sphere(x: &ArrayView1<f64>) -> f64 {
        x.iter().map(|&v| v * v).sum()
    }

    #[test]
    fn test_optimize_sphere_2d() {
        let config = BayesianOptimizerConfig {
            n_initial: 8,
            seed: Some(42),
            gp_config: GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
            ..Default::default()
        };
        let result = optimize(sphere, &[(-5.0, 5.0), (-5.0, 5.0)], 25, Some(config))
            .expect("optimization should succeed");

        assert!(result.success);
        assert!(result.f_best < 2.0, "f_best = {:.4}", result.f_best);
    }

    #[test]
    fn test_optimizer_ask_tell() {
        let config = BayesianOptimizerConfig {
            n_initial: 5,
            seed: Some(42),
            gp_config: GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut opt =
            BayesianOptimizer::new(vec![(-5.0, 5.0), (-5.0, 5.0)], config).expect("create ok");

        for _ in 0..15 {
            let x = opt.ask().expect("ask ok");
            let y = sphere(&x.view());
            opt.tell(x, y).expect("tell ok");
        }

        let best = opt.best().expect("should have a best");
        assert!(best.y < 5.0, "best y = {:.4}", best.y);
    }

    #[test]
    fn test_warm_start() {
        let config = BayesianOptimizerConfig {
            n_initial: 3,
            seed: Some(42),
            gp_config: GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut opt =
            BayesianOptimizer::new(vec![(-5.0, 5.0), (-5.0, 5.0)], config).expect("create ok");

        // Warm start with some previous data
        let x_prev =
            Array2::from_shape_vec((3, 2), vec![0.1, 0.2, -0.3, 0.1, 0.5, -0.5]).expect("shape ok");
        let y_prev = array![0.05, 0.1, 0.5];
        opt.warm_start(&x_prev, &y_prev).expect("warm start ok");

        assert_eq!(opt.n_observations(), 3);

        let result = opt.optimize(sphere, 10).expect("optimize ok");
        assert!(result.f_best < 0.5);
    }

    #[test]
    fn test_batch_optimization() {
        let config = BayesianOptimizerConfig {
            n_initial: 5,
            seed: Some(42),
            gp_config: GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut opt =
            BayesianOptimizer::new(vec![(-5.0, 5.0), (-5.0, 5.0)], config).expect("create ok");

        let result = opt
            .optimize_batch(sphere, 5, 3)
            .expect("batch optimization ok");
        assert!(result.success);
        // 5 initial + 5*3 = 20 total evaluations
        assert_eq!(result.n_evals, 20);
    }

    #[test]
    fn test_constrained_optimization() {
        let config = BayesianOptimizerConfig {
            n_initial: 8,
            seed: Some(42),
            gp_config: GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut opt =
            BayesianOptimizer::new(vec![(-5.0, 5.0), (-5.0, 5.0)], config).expect("create ok");

        // Constraint: x[0] >= 1.0 (i.e., 1.0 - x[0] <= 0)
        opt.add_constraint("x0_ge_1", |x: &ArrayView1<f64>| 1.0 - x[0]);

        let result = opt.optimize(sphere, 20).expect("optimize ok");
        // The constrained minimum of x^2+y^2 with x >= 1 is at (1,0), f=1.
        // The tolerance is loose: we only check the optimizer ended up near
        // the feasible region rather than at the unconstrained optimum (0,0).
        assert!(result.success);
        assert!(
            result.x_best[0] >= 0.5,
            "x[0] = {:.4}, expected near the feasible region x[0] >= 1",
            result.x_best[0]
        );
    }

    #[test]
    fn test_multi_objective_parego() {
        let config = BayesianOptimizerConfig {
            n_initial: 8,
            seed: Some(42),
            gp_config: GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut opt =
            BayesianOptimizer::new(vec![(-5.0, 5.0), (-5.0, 5.0)], config).expect("create ok");

        // Two objectives: f1 = (x-1)^2 + y^2, f2 = (x+1)^2 + y^2
        let f1 = |x: &ArrayView1<f64>| (x[0] - 1.0).powi(2) + x[1].powi(2);
        let f2 = |x: &ArrayView1<f64>| (x[0] + 1.0).powi(2) + x[1].powi(2);
        let objectives: Vec<Box<dyn Fn(&ArrayView1<f64>) -> f64>> =
            vec![Box::new(f1), Box::new(f2)];

        let obj_refs: Vec<&dyn Fn(&ArrayView1<f64>) -> f64> = objectives
            .iter()
            .map(|f| f.as_ref() as &dyn Fn(&ArrayView1<f64>) -> f64)
            .collect();

        // `optimize_multi_objective` takes a slice of objective references.
        let result = opt
            .optimize_multi_objective(&obj_refs[..], 15)
            .expect("multi-objective ok");
        assert!(result.success);
        // Smoke check only: the true Pareto set lies on x[0] in [-1, 1], y = 0,
        // but here we merely require the reported point to stay inside the bounds.
        assert!(result.x_best[0].abs() <= 5.0);
    }

    #[test]
    fn test_different_acquisition_functions() {
        let bounds = vec![(-3.0, 3.0)];

        for acq in &[
            AcquisitionType::EI { xi: 0.01 },
            AcquisitionType::PI { xi: 0.01 },
            AcquisitionType::UCB { kappa: 2.0 },
            AcquisitionType::Thompson { seed: 42 },
        ] {
            let config = BayesianOptimizerConfig {
                acquisition: acq.clone(),
                n_initial: 5,
                seed: Some(42),
                gp_config: GpSurrogateConfig {
                    optimize_hyperparams: false,
                    noise_variance: 1e-4,
                    ..Default::default()
                },
                ..Default::default()
            };
            let result = optimize(
                |x: &ArrayView1<f64>| x[0].powi(2),
                &bounds,
                10,
                Some(config),
            )
            .expect("optimize ok");
            assert!(
                result.f_best < 3.0,
                "Acquisition {:?} failed: f_best = {}",
                acq,
                result.f_best
            );
        }
    }

    #[test]
    fn test_invalid_bounds_rejected() {
        let result = BayesianOptimizer::new(
            vec![(5.0, 1.0)], // lo > hi
            BayesianOptimizerConfig::default(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_bounds_rejected() {
        let result = BayesianOptimizer::new(vec![], BayesianOptimizerConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_best_history_monotonic() {
        let config = BayesianOptimizerConfig {
            n_initial: 5,
            seed: Some(42),
            gp_config: GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
            ..Default::default()
        };
        let result =
            optimize(sphere, &[(-5.0, 5.0), (-5.0, 5.0)], 10, Some(config)).expect("optimize ok");

        // Best history should be non-increasing
        for i in 1..result.best_history.len() {
            assert!(
                result.best_history[i] <= result.best_history[i - 1] + 1e-12,
                "Best history not monotonic at index {}: {} > {}",
                i,
                result.best_history[i],
                result.best_history[i - 1]
            );
        }
    }

    #[test]
    fn test_parego_scalarize() {
        let obj = [0.3, 0.7];
        let w = [0.5, 0.5];
        let s = parego_scalarize(&obj, &w);
        // max(0.15, 0.35) + 0.05 * (0.15 + 0.35) = 0.35 + 0.025 = 0.375
        assert!((s - 0.375).abs() < 1e-10);
    }

    #[test]
    fn test_parego_scalarize_single_and_negative() {
        // 2.0 + 0.05 * 2.0 = 2.1
        assert!((parego_scalarize(&[2.0], &[1.0]) - 2.1).abs() < 1e-12);
        // max(-0.5, -1.0) + 0.05 * (-1.5) = -0.575
        assert!((parego_scalarize(&[-1.0, -2.0], &[0.5, 0.5]) + 0.575).abs() < 1e-12);
    }

    #[test]
    fn test_random_simplex_point_sums_to_one() {
        let mut rng = StdRng::seed_from_u64(42);
        for n in 1..6 {
            let pt = random_simplex_point(n, &mut rng);
            assert_eq!(pt.len(), n);
            let sum: f64 = pt.iter().sum();
            assert!((sum - 1.0).abs() < 1e-10, "Simplex sum = {}", sum);
            for &v in &pt {
                assert!(v >= 0.0, "Simplex component negative: {}", v);
            }
        }
    }

    #[test]
    fn test_random_simplex_point_degenerate_sizes() {
        let mut rng = StdRng::seed_from_u64(7);
        assert!(random_simplex_point(0, &mut rng).is_empty());
        assert_eq!(random_simplex_point(1, &mut rng), vec![1.0]);
    }

    #[test]
    fn test_optimize_1d() {
        let config = BayesianOptimizerConfig {
            n_initial: 5,
            seed: Some(42),
            gp_config: GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
            ..Default::default()
        };
        let result = optimize(
            |x: &ArrayView1<f64>| (x[0] - 2.0).powi(2),
            &[(-5.0, 5.0)],
            15,
            Some(config),
        )
        .expect("optimize ok");

        assert!(
            (result.x_best[0] - 2.0).abs() < 1.5,
            "x_best = {:.4}, expected ~2.0",
            result.x_best[0]
        );
        assert!(result.f_best < 2.0);
    }
}