// Copyright 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

//! Bayesian optimization with Gaussian process surrogates.
//!
//! Provides a full Bayesian optimization loop: fit a GP surrogate, evaluate
//! an acquisition function to pick the next candidate, observe the objective,
//! and iterate.  Supports RBF, Matérn-5/2, and Periodic kernels.
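//!
//! # Example
//!
//! A minimal end-to-end sketch (illustrative; the `oxiphysics_core::bayesian_opt`
//! module path is an assumption, hence the `ignore` fence):
//!
//! ```ignore
//! use oxiphysics_core::bayesian_opt::{
//!     BayesOpts, BayesianOptimizer, KernelParams, KernelType,
//! };
//!
//! // Maximize f(x) = -(x - 0.3)^2 over [0, 1]; the optimum is at x = 0.3.
//! let mut opt = BayesianOptimizer::new(
//!     vec![(0.0, 1.0)],
//!     KernelType::Rbf,
//!     KernelParams::default(),
//!     BayesOpts::default(),
//! );
//! let (best_x, best_y) = opt.optimize(|x| -(x[0] - 0.3).powi(2));
//! assert_eq!(best_x.len(), 1);
//! assert!(best_y <= 0.0);
//! ```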

#![allow(dead_code)]
#![allow(clippy::too_many_arguments)]

use std::f64::consts::{PI, SQRT_2};

// ---------------------------------------------------------------------------
// Kernel
// ---------------------------------------------------------------------------

/// The covariance kernel used by the Gaussian process.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum KernelType {
    /// Radial basis function (squared-exponential) kernel.
    Rbf,
    /// Matérn 5/2 kernel.
    Matern52,
    /// Periodic kernel.
    Periodic,
}

/// Hyper-parameters for the GP kernel and likelihood noise.
#[derive(Debug, Clone)]
pub struct KernelParams {
    /// Signal variance (amplitude squared).
    pub amplitude: f64,
    /// Length scale controlling smoothness.
    pub length_scale: f64,
    /// Observation noise variance added to the diagonal.
    pub noise_variance: f64,
    /// Period for the `Periodic` kernel (ignored by other kernels).
    pub period: f64,
}

impl Default for KernelParams {
    fn default() -> Self {
        Self {
            amplitude: 1.0,
            length_scale: 1.0,
            noise_variance: 1e-4,
            period: 1.0,
        }
    }
}

/// Evaluate the kernel between two input vectors `a` and `b`.
///
/// Returns the scalar covariance `k(a, b)`.
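///
/// With `σ² = amplitude²`, `ℓ = length_scale`, `d = ‖a − b‖`, and period `p`,
/// the implemented forms are:
///
/// - RBF:        `σ² exp(−d² / (2ℓ²))`
/// - Matérn-5/2: `σ² (1 + r + r²/3) exp(−r)` with `r = √5 d / ℓ`
/// - Periodic:   `σ² exp(−2 sin²(π d / p) / ℓ²)`
///
/// # Example
///
/// A quick sanity check (illustrative; the module path is assumed, hence the
/// `ignore` fence): with default parameters, `k(x, x) = σ² = 1`.
///
/// ```ignore
/// use oxiphysics_core::bayesian_opt::{kernel_eval, KernelParams, KernelType};
///
/// let p = KernelParams::default();
/// let v = kernel_eval(KernelType::Rbf, &p, &[1.0, 2.0], &[1.0, 2.0]);
/// assert!((v - 1.0).abs() < 1e-12);
/// ```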
pub fn kernel_eval(kt: KernelType, params: &KernelParams, a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    let sq_dist: f64 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
    let dist = sq_dist.sqrt();
    let l = params.length_scale.max(1e-12);
    let amp2 = params.amplitude * params.amplitude;
    match kt {
        KernelType::Rbf => amp2 * (-0.5 * sq_dist / (l * l)).exp(),
        KernelType::Matern52 => {
            // Matérn-5/2: k(d) = σ² (1 + r + r²/3) exp(-r) with r = √5 d / ℓ.
            let r = 5_f64.sqrt() * dist / l;
            amp2 * (1.0 + r + r * r / 3.0) * (-r).exp()
        }
        KernelType::Periodic => {
            // Guard the period the same way as the length scale.
            let sin_arg = PI * dist / params.period.max(1e-12);
            amp2 * (-2.0 * sin_arg.sin().powi(2) / (l * l)).exp()
        }
    }
}

// ---------------------------------------------------------------------------
// Cholesky helpers (used by GP)
// ---------------------------------------------------------------------------

/// Compute the lower-triangular Cholesky factor `L` such that `A = L Lᵀ`.
///
/// `A` is stored row-major in a flat `Vec<f64>` of length `n*n`.
/// Returns `Err` if the matrix is not positive-definite.
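///
/// # Example
///
/// A 2×2 sketch mirroring the unit test below (module path assumed, hence the
/// `ignore` fence): `A = [[4, 2], [2, 3]]` factors as `L = [[2, 0], [1, √2]]`.
///
/// ```ignore
/// use oxiphysics_core::bayesian_opt::cholesky;
///
/// let a = vec![4.0, 2.0, 2.0, 3.0];
/// let l = cholesky(&a, 2).unwrap();
/// assert!((l[0] - 2.0).abs() < 1e-10);
/// assert!((l[3] - 2_f64.sqrt()).abs() < 1e-10);
/// ```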
pub fn cholesky(a: &[f64], n: usize) -> Result<Vec<f64>, String> {
    let mut l = vec![0.0_f64; n * n];
    for i in 0..n {
        for j in 0..=i {
            let s: f64 = (0..j).map(|k| l[i * n + k] * l[j * n + k]).sum();
            if i == j {
                let val = a[i * n + i] - s;
                // A strictly positive pivot is required; a zero pivot would
                // put a zero on the diagonal and break the triangular solves.
                if val <= 0.0 {
                    return Err(format!(
                        "matrix is not positive-definite at diagonal ({i},{i})"
                    ));
                }
                l[i * n + j] = val.sqrt();
            } else {
                let ljj = l[j * n + j];
                if ljj.abs() < 1e-15 {
                    return Err("near-zero diagonal in Cholesky".into());
                }
                l[i * n + j] = (a[i * n + j] - s) / ljj;
            }
        }
    }
    Ok(l)
}

/// Solve `L x = b` for `x` (forward substitution).  `L` is lower-triangular.
fn solve_lower(l: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    let mut x = vec![0.0; n];
    for i in 0..n {
        let s: f64 = (0..i).map(|j| l[i * n + j] * x[j]).sum();
        x[i] = (b[i] - s) / l[i * n + i];
    }
    x
}

/// Solve `Lᵀ x = b` for `x` (back substitution).  `L` is lower-triangular.
fn solve_upper(l: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let s: f64 = (i + 1..n).map(|j| l[j * n + i] * x[j]).sum();
        x[i] = (b[i] - s) / l[i * n + i];
    }
    x
}

// ---------------------------------------------------------------------------
// GaussianProcess
// ---------------------------------------------------------------------------

/// A Gaussian process regressor using a stationary kernel.
///
/// Training points are stored internally.  After [`GaussianProcess::fit`] the
/// GP can predict the posterior mean and variance at arbitrary test points.
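///
/// # Example
///
/// A fit/predict sketch on 1-D data (illustrative; module path assumed, hence
/// the `ignore` fence):
///
/// ```ignore
/// use oxiphysics_core::bayesian_opt::{GaussianProcess, KernelParams, KernelType};
///
/// let mut gp = GaussianProcess::new(KernelType::Rbf, KernelParams::default());
/// gp.fit(vec![vec![0.0], vec![1.0], vec![2.0]], vec![0.0, 1.0, 0.0]).unwrap();
/// let (mean, var) = gp.predict(&[1.0]);
/// assert!((mean - 1.0).abs() < 0.1); // near-interpolation with small noise
/// assert!(var >= 0.0);
/// ```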
#[derive(Debug, Clone)]
pub struct GaussianProcess {
    /// Kernel type.
    pub kernel: KernelType,
    /// Kernel hyper-parameters.
    pub params: KernelParams,
    /// Training inputs (each row is one sample).
    x_train: Vec<Vec<f64>>,
    /// Training targets.
    y_train: Vec<f64>,
    /// Cholesky factor of the covariance matrix (row-major, `n × n`).
    chol: Vec<f64>,
    /// α = K⁻¹ y (used for mean prediction).
    alpha: Vec<f64>,
    /// Whether the GP has been fitted.
    fitted: bool,
}

impl GaussianProcess {
    /// Construct a new, unfitted GP with the given kernel and parameters.
    pub fn new(kernel: KernelType, params: KernelParams) -> Self {
        Self {
            kernel,
            params,
            x_train: Vec::new(),
            y_train: Vec::new(),
            chol: Vec::new(),
            alpha: Vec::new(),
            fitted: false,
        }
    }

    /// Fit the GP to the given training data.
    ///
    /// # Panics
    /// Panics if `x` and `y` have different lengths, or if `x` is empty.
    pub fn fit(&mut self, x: Vec<Vec<f64>>, y: Vec<f64>) -> Result<(), String> {
        assert_eq!(x.len(), y.len(), "x and y must have the same length");
        assert!(!x.is_empty(), "Training set must be non-empty");

        let n = x.len();
        // Build the n×n kernel matrix + noise
        let mut k = vec![0.0_f64; n * n];
        for i in 0..n {
            for j in 0..n {
                let kval = kernel_eval(self.kernel, &self.params, &x[i], &x[j]);
                k[i * n + j] = kval;
            }
            // Add noise to the diagonal
            k[i * n + i] += self.params.noise_variance;
        }

        let l = cholesky(&k, n)?;
        // Solve K α = y  →  L α' = y, then Lᵀ α = α'
        let alpha_tmp = solve_lower(&l, &y, n);
        let alpha = solve_upper(&l, &alpha_tmp, n);

        self.x_train = x;
        self.y_train = y;
        self.chol = l;
        self.alpha = alpha;
        self.fitted = true;
        Ok(())
    }

    /// Predict the posterior (mean, variance) at test point `x_star`.
    ///
    /// Returns `(mean, variance)`.
    ///
    /// # Panics
    /// Panics if the GP has not been fitted.
    pub fn predict(&self, x_star: &[f64]) -> (f64, f64) {
        assert!(self.fitted, "GP must be fitted before calling predict");
        let n = self.x_train.len();

        // k_star = [k(x*, x_1), …, k(x*, x_n)]
        let k_star: Vec<f64> = self
            .x_train
            .iter()
            .map(|xi| kernel_eval(self.kernel, &self.params, x_star, xi))
            .collect();

        // mean = k_starᵀ α
        let mean: f64 = k_star
            .iter()
            .zip(self.alpha.iter())
            .map(|(a, b)| a * b)
            .sum();

        // var = k(x*, x*) - k_starᵀ K⁻¹ k_star
        //     = k(x*, x*) - v·v  where v = L⁻¹ k_star
        let k_ss = kernel_eval(self.kernel, &self.params, x_star, x_star);
        let v = solve_lower(&self.chol, &k_star, n);
        let var = (k_ss - v.iter().map(|vi| vi * vi).sum::<f64>()).max(0.0);

        (mean, var)
    }

    /// Return the number of training points.
    pub fn n_train(&self) -> usize {
        self.x_train.len()
    }

    /// Return whether the GP has been fitted.
    pub fn is_fitted(&self) -> bool {
        self.fitted
    }
}

// ---------------------------------------------------------------------------
// Acquisition functions
// ---------------------------------------------------------------------------

/// Acquisition function used to select the next candidate point.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AcquisitionFn {
    /// Expected Improvement.
    ///
    /// Balances exploitation (improving over the current best) and
    /// exploration (exploring uncertain regions).
    ExpectedImprovement,
    /// Upper Confidence Bound.
    ///
    /// Combines mean and standard deviation: `μ + κ σ`.
    UpperConfidenceBound,
    /// Probability of Improvement.
    ///
    /// Probability that the next point exceeds the current best.
    ProbabilityOfImprovement,
}

/// Standard normal CDF Φ(z).
fn standard_normal_cdf(z: f64) -> f64 {
    0.5 * (1.0 + libm_erf(z / SQRT_2))
}

/// Standard normal PDF φ(z).
fn standard_normal_pdf(z: f64) -> f64 {
    (-0.5 * z * z).exp() / (2.0 * PI).sqrt()
}

/// Error function approximation (Abramowitz & Stegun 7.1.26, max error < 1.5e-7).
fn libm_erf(x: f64) -> f64 {
    // Use the sign symmetry erf(-x) = -erf(x)
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = t
        * (0.254829592
            + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    sign * (1.0 - poly * (-x * x).exp())
}

/// Evaluate the acquisition function at a single point.
///
/// - `mean`, `var` are the GP posterior moments at the candidate.
/// - `best_y`  is the best observed value so far (for EI and PI).
/// - `kappa`   is the exploration weight for UCB.
/// - `xi`      is the exploration–exploitation trade-off for EI/PI.
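///
/// With `σ = √var` and `z = (mean − best_y − ξ) / σ`, the implemented forms
/// are:
///
/// - EI:  `(mean − best_y − ξ) Φ(z) + σ φ(z)`
/// - UCB: `mean + κ σ`
/// - PI:  `Φ(z)`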
pub fn acquisition_value(
    acq: AcquisitionFn,
    mean: f64,
    var: f64,
    best_y: f64,
    kappa: f64,
    xi: f64,
) -> f64 {
    let sigma = var.sqrt().max(1e-12);
    match acq {
        AcquisitionFn::ExpectedImprovement => {
            let z = (mean - best_y - xi) / sigma;
            // EI is non-negative by definition; clamp away tiny negative
            // values introduced by the erf approximation.
            ((mean - best_y - xi) * standard_normal_cdf(z) + sigma * standard_normal_pdf(z))
                .max(0.0)
        }
        AcquisitionFn::UpperConfidenceBound => mean + kappa * sigma,
        AcquisitionFn::ProbabilityOfImprovement => {
            let z = (mean - best_y - xi) / sigma;
            standard_normal_cdf(z)
        }
    }
}

// ---------------------------------------------------------------------------
// Latin-hypercube sampling
// ---------------------------------------------------------------------------

/// Generate a Latin hypercube sample of `n` points in a `dim`-dimensional box.
///
/// Each axis is divided into `n` equal intervals; one point is sampled from
/// each interval per axis.  The result is a `Vec` of `n` points, each a
/// `Vec<f64>` of length `dim`.  The `bounds` slice must have length `dim`.
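///
/// # Example
///
/// A shape/bounds sketch (illustrative; module path assumed, hence the
/// `ignore` fence):
///
/// ```ignore
/// use oxiphysics_core::bayesian_opt::latin_hypercube_sample;
///
/// let samples = latin_hypercube_sample(10, 2, &[(0.0, 1.0), (-1.0, 1.0)]);
/// assert_eq!(samples.len(), 10);
/// assert!(samples.iter().all(|s| s.len() == 2));
/// ```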
pub fn latin_hypercube_sample(n: usize, dim: usize, bounds: &[(f64, f64)]) -> Vec<Vec<f64>> {
    assert_eq!(bounds.len(), dim, "bounds length must equal dim");
    if n == 0 || dim == 0 {
        return Vec::new();
    }

    let mut rng = rand::rng();
    use rand::Rng as _;

    // For each dimension, assign the points to a random permutation of the
    // intervals, then sample uniformly within each assigned interval.
    let mut samples = vec![vec![0.0_f64; dim]; n];

    for d in 0..dim {
        let (lo, hi) = bounds[d];
        let interval = (hi - lo) / n as f64;

        // Create indices 0..n and shuffle them (Fisher-Yates)
        let mut order: Vec<usize> = (0..n).collect();
        for i in (1..n).rev() {
            let j = rng.random_range(0..=i);
            order.swap(i, j);
        }

        for (i, &slot) in order.iter().enumerate() {
            let base = lo + slot as f64 * interval;
            let jitter = rng.random_range(0.0..interval);
            samples[i][d] = base + jitter;
        }
    }
    samples
}

// ---------------------------------------------------------------------------
// BayesianOptimizer
// ---------------------------------------------------------------------------

/// Configuration for [`BayesianOptimizer`].
#[derive(Debug, Clone)]
pub struct BayesOpts {
    /// Number of initial random points (LHS sample) before GP is used.
    pub n_initial: usize,
    /// Maximum number of optimization iterations.
    pub max_iter: usize,
    /// Number of random candidates evaluated to maximise the acquisition fn.
    pub n_candidates: usize,
    /// Acquisition function to use.
    pub acquisition: AcquisitionFn,
    /// UCB exploration weight κ.
    pub kappa: f64,
    /// EI/PI exploration offset ξ.
    pub xi: f64,
}

impl Default for BayesOpts {
    fn default() -> Self {
        Self {
            n_initial: 5,
            max_iter: 20,
            n_candidates: 512,
            acquisition: AcquisitionFn::ExpectedImprovement,
            kappa: 2.576,
            xi: 0.01,
        }
    }
}

/// Bayesian optimization over a bounded box using a GP surrogate.
///
/// The optimizer maintains a GP fitted to all observations so far, and at
/// each step proposes the point that maximises the acquisition function.
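///
/// # Example
///
/// An ask-tell sketch driving the loop manually instead of via
/// [`BayesianOptimizer::optimize`] (illustrative; module path assumed, hence
/// the `ignore` fence):
///
/// ```ignore
/// use oxiphysics_core::bayesian_opt::{
///     BayesOpts, BayesianOptimizer, KernelParams, KernelType,
/// };
///
/// let mut opt = BayesianOptimizer::new(
///     vec![(0.0, 1.0)],
///     KernelType::Rbf,
///     KernelParams::default(),
///     BayesOpts::default(),
/// );
/// for _ in 0..10 {
///     let x = opt.suggest_next();      // ask
///     let y = -(x[0] - 0.3).powi(2);   // evaluate the objective
///     opt.update(x, y).unwrap();       // tell
/// }
/// assert!(opt.best_y <= 0.0);
/// ```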
#[derive(Debug, Clone)]
pub struct BayesianOptimizer {
    /// Search space bounds: one `(lo, hi)` per dimension.
    pub bounds: Vec<(f64, f64)>,
    /// GP surrogate.
    pub gp: GaussianProcess,
    /// All observed inputs.
    pub x_observed: Vec<Vec<f64>>,
    /// All observed outputs.
    pub y_observed: Vec<f64>,
    /// Best output value observed so far.
    pub best_y: f64,
    /// Input that yielded `best_y`.
    pub best_x: Vec<f64>,
    /// Optimizer configuration.
    pub opts: BayesOpts,
}

impl BayesianOptimizer {
    /// Construct a new optimizer.
    ///
    /// - `bounds` — axis-aligned box, one `(lo, hi)` per input dimension.
    /// - `kernel` — kernel type for the GP surrogate.
    /// - `params` — kernel hyper-parameters.
    /// - `opts`   — algorithm options (number of iterations, acquisition, …).
    pub fn new(
        bounds: Vec<(f64, f64)>,
        kernel: KernelType,
        params: KernelParams,
        opts: BayesOpts,
    ) -> Self {
        let gp = GaussianProcess::new(kernel, params);
        Self {
            bounds,
            gp,
            x_observed: Vec::new(),
            y_observed: Vec::new(),
            best_y: f64::NEG_INFINITY,
            best_x: Vec::new(),
            opts,
        }
    }

    /// Incorporate a new observation `(x, y)` into the optimizer state.
    ///
    /// The GP surrogate is re-fitted after each call.
    pub fn update(&mut self, x: Vec<f64>, y: f64) -> Result<(), String> {
        if y > self.best_y {
            self.best_y = y;
            self.best_x = x.clone();
        }
        self.x_observed.push(x);
        self.y_observed.push(y);

        // Re-fit the GP
        self.gp
            .fit(self.x_observed.clone(), self.y_observed.clone())
    }

    /// Suggest the next candidate point to evaluate.
    ///
    /// Uses Latin-hypercube random candidates and picks the one with the
    /// highest acquisition value.  Falls back to a random LHS point if the
    /// GP has not been fitted yet.
    pub fn suggest_next(&self) -> Vec<f64> {
        let candidates =
            latin_hypercube_sample(self.opts.n_candidates, self.bounds.len(), &self.bounds);

        if !self.gp.is_fitted() {
            // Before any GP fit, just return the first candidate
            return candidates
                .into_iter()
                .next()
                .unwrap_or_else(|| self.bounds.iter().map(|(lo, hi)| (lo + hi) / 2.0).collect());
        }

        let best_y = self.best_y;
        let acq = self.opts.acquisition;
        let kappa = self.opts.kappa;
        let xi = self.opts.xi;

        let mut best_acq = f64::NEG_INFINITY;
        let mut best_candidate = candidates[0].clone();

        for cand in &candidates {
            let (mean, var) = self.gp.predict(cand);
            let val = acquisition_value(acq, mean, var, best_y, kappa, xi);
            if val > best_acq {
                best_acq = val;
                best_candidate = cand.clone();
            }
        }
        best_candidate
    }

    /// Run the full optimization loop, evaluating the black-box `f`.
    ///
    /// First draws `n_initial` LHS points, then iterates for `max_iter`
    /// steps.  Returns the best `(x, y)` pair found.
    pub fn optimize<F>(&mut self, f: F) -> (Vec<f64>, f64)
    where
        F: Fn(&[f64]) -> f64,
    {
        // --- Initial random exploration ---
        let init_samples =
            latin_hypercube_sample(self.opts.n_initial, self.bounds.len(), &self.bounds);
        for x in init_samples {
            let y = f(&x);
            let _ = self.update(x, y);
        }

        // --- Bayesian iterations ---
        for _ in 0..self.opts.max_iter {
            let x_next = self.suggest_next();
            let y_next = f(&x_next);
            let _ = self.update(x_next, y_next);
        }

        (self.best_x.clone(), self.best_y)
    }

    /// Number of observations collected so far.
    pub fn n_observations(&self) -> usize {
        self.x_observed.len()
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // Helper: build a small 1-D GP fitted on a few points
    fn simple_gp() -> GaussianProcess {
        let mut gp = GaussianProcess::new(KernelType::Rbf, KernelParams::default());
        let x: Vec<Vec<f64>> = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0]];
        let y: Vec<f64> = vec![0.0, 1.0, 0.0, -1.0];
        gp.fit(x, y).expect("fit should succeed");
        gp
    }

    // ---- Kernel tests ----

    #[test]
    fn test_rbf_kernel_same_point() {
        let p = KernelParams::default();
        let v = kernel_eval(KernelType::Rbf, &p, &[1.0, 2.0], &[1.0, 2.0]);
        // k(x,x) = amp^2 * exp(0) = 1.0
        assert!((v - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_rbf_kernel_decreases_with_distance() {
        let p = KernelParams::default();
        let k1 = kernel_eval(KernelType::Rbf, &p, &[0.0], &[0.5]);
        let k2 = kernel_eval(KernelType::Rbf, &p, &[0.0], &[1.5]);
        assert!(k1 > k2, "RBF should decrease with distance");
    }

    #[test]
    fn test_rbf_kernel_symmetry() {
        let p = KernelParams::default();
        let a = kernel_eval(KernelType::Rbf, &p, &[1.0, 2.0], &[3.0, 4.0]);
        let b = kernel_eval(KernelType::Rbf, &p, &[3.0, 4.0], &[1.0, 2.0]);
        assert!((a - b).abs() < 1e-14);
    }

    #[test]
    fn test_matern52_kernel_same_point() {
        let p = KernelParams::default();
        let v = kernel_eval(KernelType::Matern52, &p, &[0.0], &[0.0]);
        assert!((v - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_matern52_kernel_positive() {
        let p = KernelParams::default();
        let v = kernel_eval(KernelType::Matern52, &p, &[0.0], &[2.0]);
        assert!(v >= 0.0);
    }

    #[test]
    fn test_periodic_kernel_same_point() {
        let p = KernelParams::default();
        let v = kernel_eval(KernelType::Periodic, &p, &[0.0], &[0.0]);
        // sin(0) = 0 → exp(0) = 1
        assert!((v - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_periodic_kernel_period_recovery() {
        // k(x, x+period) should equal k(x, x) for the periodic kernel
        let p = KernelParams {
            period: 2.0,
            ..Default::default()
        };
        let v0 = kernel_eval(KernelType::Periodic, &p, &[0.0], &[0.0]);
        let v1 = kernel_eval(KernelType::Periodic, &p, &[0.0], &[2.0]);
        assert!((v0 - v1).abs() < 1e-12);
    }
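
    // Illustrative closed-form check of the Matérn-5/2 kernel at unit
    // distance (assumes the standard r = √5 d / ℓ scaling and the default
    // unit amplitude and length scale).
    #[test]
    fn test_matern52_known_value() {
        let p = KernelParams::default();
        let v = kernel_eval(KernelType::Matern52, &p, &[0.0], &[1.0]);
        let r = 5_f64.sqrt();
        let expected = (1.0 + r + r * r / 3.0) * (-r).exp();
        assert!((v - expected).abs() < 1e-12);
    }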

    // ---- Cholesky tests ----

    #[test]
    fn test_cholesky_2x2() {
        // A = [[4, 2],[2, 3]]  →  L = [[2, 0],[1, sqrt(2)]]
        let a = vec![4.0, 2.0, 2.0, 3.0];
        let l = cholesky(&a, 2).unwrap();
        assert!((l[0] - 2.0).abs() < 1e-10);
        assert!((l[1]).abs() < 1e-10);
        assert!((l[2] - 1.0).abs() < 1e-10);
        assert!((l[3] - 2_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_cholesky_identity() {
        let a = vec![1.0, 0.0, 0.0, 1.0];
        let l = cholesky(&a, 2).unwrap();
        // L should be the identity
        assert!((l[0] - 1.0).abs() < 1e-12);
        assert!((l[1]).abs() < 1e-12);
        assert!((l[2]).abs() < 1e-12);
        assert!((l[3] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_cholesky_not_pd_returns_err() {
        // Matrix [[-1, 0],[0, 1]] is not PD
        let a = vec![-1.0, 0.0, 0.0, 1.0];
        assert!(cholesky(&a, 2).is_err());
    }

    // ---- GP fit / predict tests ----

    #[test]
    fn test_gp_fit_succeeds() {
        let gp = simple_gp();
        assert!(gp.is_fitted());
        assert_eq!(gp.n_train(), 4);
    }

    #[test]
    fn test_gp_predict_at_training_points_close() {
        let gp = simple_gp();
        // With very small noise the posterior mean at training points should be close
        let (mean, var) = gp.predict(&[1.0]);
        assert!(
            (mean - 1.0).abs() < 0.1,
            "mean at x=1 should be ~1, got {mean}"
        );
        assert!(var >= 0.0);
    }

    #[test]
    fn test_gp_variance_nonnegative() {
        let gp = simple_gp();
        for x in [-1.0, 0.5, 1.5, 4.0] {
            let (_, var) = gp.predict(&[x]);
            assert!(var >= 0.0, "variance must be non-negative at x={x}");
        }
    }

    #[test]
    fn test_gp_variance_higher_far_from_data() {
        let gp = simple_gp();
        let (_, var_near) = gp.predict(&[1.5]);
        let (_, var_far) = gp.predict(&[100.0]);
        assert!(
            var_far > var_near,
            "variance should be higher far from training data"
        );
    }

    #[test]
    fn test_gp_fit_empty_panics() {
        let result = std::panic::catch_unwind(|| {
            let mut gp = GaussianProcess::new(KernelType::Rbf, KernelParams::default());
            let _ = gp.fit(vec![], vec![]);
        });
        assert!(result.is_err(), "fit with empty data should panic");
    }

    #[test]
    fn test_gp_matern_fit() {
        let mut gp = GaussianProcess::new(KernelType::Matern52, KernelParams::default());
        let x: Vec<Vec<f64>> = (0..5).map(|i| vec![i as f64]).collect();
        let y: Vec<f64> = x.iter().map(|v| v[0].sin()).collect();
        assert!(gp.fit(x, y).is_ok());
        assert!(gp.is_fitted());
    }

    #[test]
    fn test_gp_periodic_fit() {
        let p = KernelParams {
            period: PI,
            ..Default::default()
        };
        let mut gp = GaussianProcess::new(KernelType::Periodic, p);
        let x: Vec<Vec<f64>> = (0..6).map(|i| vec![i as f64 * 0.5]).collect();
        let y: Vec<f64> = x.iter().map(|v| v[0].sin()).collect();
        assert!(gp.fit(x, y).is_ok());
    }

    // ---- Acquisition function tests ----

    #[test]
    fn test_ei_nonnegative() {
        let val = acquisition_value(
            AcquisitionFn::ExpectedImprovement,
            1.5,
            0.25,
            1.0,
            2.0,
            0.01,
        );
        assert!(val >= 0.0, "EI must be non-negative");
    }

    #[test]
    fn test_ucb_increases_with_variance() {
        let low_var = acquisition_value(
            AcquisitionFn::UpperConfidenceBound,
            1.0,
            0.01,
            0.0,
            2.0,
            0.0,
        );
        let high_var =
            acquisition_value(AcquisitionFn::UpperConfidenceBound, 1.0, 1.0, 0.0, 2.0, 0.0);
        assert!(high_var > low_var, "UCB should increase with variance");
    }

    #[test]
    fn test_pi_in_unit_interval() {
        let val = acquisition_value(
            AcquisitionFn::ProbabilityOfImprovement,
            1.5,
            0.25,
            1.0,
            2.0,
            0.0,
        );
        assert!((0.0..=1.0).contains(&val), "PI must be in [0, 1]");
    }

    #[test]
    fn test_pi_zero_when_mean_below_best() {
        // If mean is much lower than best_y, PI should be ~0
        let val = acquisition_value(
            AcquisitionFn::ProbabilityOfImprovement,
            -100.0,
            0.01,
            1.0,
            2.0,
            0.0,
        );
        assert!(val < 0.01, "PI should be near 0 when far below best_y");
    }

    #[test]
    fn test_ei_zero_with_negative_improvement() {
        // mean < best_y + xi, and tiny variance → EI is effectively 0 (but ≥ 0)
        let val = acquisition_value(
            AcquisitionFn::ExpectedImprovement,
            0.0,
            1e-8,
            10.0,
            2.0,
            0.01,
        );
        assert!(val >= 0.0);
        assert!(val < 1e-3);
    }

    // ---- Latin-hypercube sampling tests ----

    #[test]
    fn test_lhs_shape() {
        let samples = latin_hypercube_sample(10, 3, &[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)]);
        assert_eq!(samples.len(), 10);
        for s in &samples {
            assert_eq!(s.len(), 3);
        }
    }

    #[test]
    fn test_lhs_within_bounds() {
        let bounds = vec![(2.0, 5.0), (-1.0, 1.0)];
        let samples = latin_hypercube_sample(20, 2, &bounds);
        for s in &samples {
            assert!(s[0] >= 2.0 && s[0] <= 5.0);
            assert!(s[1] >= -1.0 && s[1] <= 1.0);
        }
    }

    #[test]
    fn test_lhs_zero_samples() {
        let s = latin_hypercube_sample(0, 2, &[(0.0, 1.0), (0.0, 1.0)]);
        assert!(s.is_empty());
    }

    #[test]
    fn test_lhs_coverage() {
        // With n=4 samples in 1D [0,4], each unit interval should be covered
        let samples = latin_hypercube_sample(4, 1, &[(0.0, 4.0)]);
        let mut covered = [false; 4];
        for s in &samples {
            let slot = (s[0] as usize).min(3);
            covered[slot] = true;
        }
        assert!(
            covered.iter().all(|&c| c),
            "each interval should be covered"
        );
    }

    // ---- BayesianOptimizer tests ----

    #[test]
    fn test_optimizer_update_increments_count() {
        let mut opt = BayesianOptimizer::new(
            vec![(0.0, 1.0)],
            KernelType::Rbf,
            KernelParams::default(),
            BayesOpts::default(),
        );
        opt.update(vec![0.5], 1.0).unwrap();
        opt.update(vec![0.7], 2.0).unwrap();
        assert_eq!(opt.n_observations(), 2);
    }

    #[test]
    fn test_optimizer_tracks_best() {
        let mut opt = BayesianOptimizer::new(
            vec![(0.0, 1.0)],
            KernelType::Rbf,
            KernelParams::default(),
            BayesOpts::default(),
        );
        opt.update(vec![0.1], 0.5).unwrap();
        opt.update(vec![0.9], 3.0).unwrap();
        opt.update(vec![0.5], 1.0).unwrap();
        assert!((opt.best_y - 3.0).abs() < 1e-12);
        assert!((opt.best_x[0] - 0.9).abs() < 1e-12);
    }

    #[test]
    fn test_optimizer_suggest_before_fit_returns_point() {
        let opt = BayesianOptimizer::new(
            vec![(0.0, 1.0), (0.0, 1.0)],
            KernelType::Rbf,
            KernelParams::default(),
            BayesOpts::default(),
        );
        let x = opt.suggest_next();
        assert_eq!(x.len(), 2);
        assert!(x[0] >= 0.0 && x[0] <= 1.0);
        assert!(x[1] >= 0.0 && x[1] <= 1.0);
    }

    #[test]
    fn test_optimizer_convergence_quadratic() {
        // Maximize f(x) = -(x - 0.3)^2 over [0, 1] — optimum at x=0.3
        let mut opt = BayesianOptimizer::new(
            vec![(0.0, 1.0)],
            KernelType::Rbf,
            KernelParams::default(),
            BayesOpts {
                n_initial: 5,
                max_iter: 15,
                n_candidates: 256,
                acquisition: AcquisitionFn::ExpectedImprovement,
                ..BayesOpts::default()
            },
        );
        let (best_x, best_y) = opt.optimize(|x| -(x[0] - 0.3).powi(2));
        // We expect to get reasonably close to the true optimum (y=0)
        assert!(
            best_y > -0.1,
            "optimizer should find near-optimum, got y={best_y}"
        );
        assert!(
            (best_x[0] - 0.3).abs() < 0.4,
            "optimizer should find x~0.3, got x={}",
            best_x[0]
        );
    }

    #[test]
    fn test_optimizer_convergence_sinusoidal() {
        // Maximize sin(2π x) over [0, 1] — global max at x=0.25
        let mut opt = BayesianOptimizer::new(
            vec![(0.0, 1.0)],
            KernelType::Rbf,
            KernelParams::default(),
            BayesOpts {
                n_initial: 8,
                max_iter: 20,
                n_candidates: 512,
                acquisition: AcquisitionFn::UpperConfidenceBound,
                kappa: 2.0,
                ..BayesOpts::default()
            },
        );
        let (_best_x, best_y) = opt.optimize(|x| (2.0 * PI * x[0]).sin());
        assert!(best_y > 0.9, "should reach sin peak, got {best_y}");
    }

    #[test]
    fn test_optimizer_2d_convergence() {
        // Maximize -(x^2 + y^2) over [-2,2]^2 — optimum at (0,0)
        let mut opt = BayesianOptimizer::new(
            vec![(-2.0, 2.0), (-2.0, 2.0)],
            KernelType::Rbf,
            KernelParams {
                length_scale: 1.5,
                ..KernelParams::default()
            },
            BayesOpts {
                n_initial: 6,
                max_iter: 20,
                n_candidates: 512,
                acquisition: AcquisitionFn::ExpectedImprovement,
                ..BayesOpts::default()
            },
        );
        let (_best_x, best_y) = opt.optimize(|x| -(x[0].powi(2) + x[1].powi(2)));
        assert!(best_y > -1.0, "should find near-origin, got y={best_y}");
    }

    #[test]
    fn test_standard_normal_cdf_symmetry() {
        // Φ(-z) = 1 - Φ(z)
        for z in [-2.0, -1.0, 0.0, 1.0, 2.0] {
            let sum = standard_normal_cdf(z) + standard_normal_cdf(-z);
            assert!((sum - 1.0).abs() < 1e-6, "CDF symmetry failed at z={z}");
        }
    }

    #[test]
    fn test_standard_normal_cdf_midpoint() {
        assert!((standard_normal_cdf(0.0) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_erf_known_values() {
        // erf(0) ≈ 0 (A&S approximation error < 1e-8)
        assert!(libm_erf(0.0).abs() < 1e-8);
        // erf(x) → 1 for large x; check at x = 5
        assert!((libm_erf(5.0) - 1.0).abs() < 1e-5);
        // erf(-x) = -erf(x)
        assert!((libm_erf(-1.0) + libm_erf(1.0)).abs() < 1e-10);
    }
}