// finance_query/backtesting/optimizer/bayesian.rs

//! Sequential model-based (Bayesian) parameter optimisation.
//!
//! [`BayesianSearch`] finds near-optimal strategy parameters in far fewer
//! backtests than exhaustive [`GridSearch`] — typically 50–200 evaluations
//! instead of thousands — by building a statistical surrogate model of the
//! objective and directing search toward promising, under-explored regions.
//!
//! # Algorithm (SAMBO — Sequential Adaptive Model-Based Optimisation)
//!
//! 1. **Exploration phase** — Sample `initial_points` parameter sets using
//!    [Latin Hypercube Sampling] (LHS) to guarantee good initial coverage of
//!    the search space.
//! 2. **Sequential phase** — Fit a [Nadaraya-Watson kernel regression]
//!    surrogate to all `(params, score)` observations. Generate `N_CANDIDATES`
//!    random candidates and score each with the [Upper Confidence Bound] (UCB)
//!    acquisition function `a(x) = μ(x) + β·σ(x)`. Run the backtest for the
//!    highest-scoring candidate, add the observation, and repeat.
//! 3. **Convergence** — Stop when `max_evaluations` is reached.
//!
//! [Latin Hypercube Sampling]: https://en.wikipedia.org/wiki/Latin_hypercube_sampling
//! [Nadaraya-Watson kernel regression]: https://en.wikipedia.org/wiki/Kernel_regression
//! [Upper Confidence Bound]: https://en.wikipedia.org/wiki/Multi-armed_bandit#Upper_confidence_bound
//!
//! # Example
//!
//! ```ignore
//! use finance_query::backtesting::{
//!     BacktestConfig, SmaCrossover,
//!     optimizer::{BayesianSearch, OptimizeMetric, ParamRange},
//! };
//!
//! # fn example(candles: &[finance_query::models::chart::Candle]) {
//! let report = BayesianSearch::new()
//!     .param("fast", ParamRange::int_bounds(5, 50))
//!     .param("slow", ParamRange::int_bounds(20, 200))
//!     .param("rsi_period", ParamRange::int_bounds(7, 21))
//!     .param("threshold", ParamRange::float_bounds(0.3, 0.7))
//!     .optimize_for(OptimizeMetric::SharpeRatio)
//!     .max_evaluations(100)
//!     .run("AAPL", &candles, &BacktestConfig::default(), |params| {
//!         SmaCrossover::new(
//!             params["fast"].as_int() as usize,
//!             params["slow"].as_int() as usize,
//!         )
//!     })
//!     .unwrap();
//!
//! println!("Best params:  {:?}", report.best.params);
//! println!("Best Sharpe:  {:.2}", report.best.result.metrics.sharpe_ratio);
//! println!("Evaluations:  {}", report.n_evaluations);
//! # }
//! ```
53
54use std::collections::HashMap;
55
56use crate::models::chart::Candle;
57
58use super::super::config::BacktestConfig;
59use super::super::engine::BacktestEngine;
60use super::super::error::{BacktestError, Result};
61use super::super::monte_carlo::Xorshift64;
62use super::super::strategy::Strategy;
63use super::{
64    OptimizationReport, OptimizationResult, OptimizeMetric, ParamRange, ParamValue,
65    sort_results_best_first,
66};
67
// ── Defaults ──────────────────────────────────────────────────────────────────

/// Default evaluation budget (`max_evaluations`), covering both phases.
const DEFAULT_MAX_EVALUATIONS: usize = 100;
/// Default number of Latin-Hypercube samples drawn before the surrogate is fitted.
const DEFAULT_INITIAL_POINTS: usize = 10;
/// β = 2.0 balances exploitation and exploration for objectives in [0, 1].
const DEFAULT_UCB_BETA: f64 = 2.0;
/// Default PRNG seed so unseeded runs are still reproducible.
const DEFAULT_SEED: u64 = 42;
/// Candidates evaluated per acquisition step. 1 000 reliably finds the UCB
/// maximum without meaningful overhead (pure floating-point math, no backtests).
const N_CANDIDATES: usize = 1_000;
78
// ── BayesianSearch ────────────────────────────────────────────────────────────

/// Sequential model-based (Bayesian) parameter optimiser.
///
/// Finds near-optimal strategy parameters in a fraction of the evaluations
/// required by exhaustive [`GridSearch`], making it practical for
/// high-dimensional spaces or continuous float ranges.
///
/// Returns the same [`OptimizationReport`] as [`GridSearch`], so the two are
/// drop-in interchangeable and both work with [`WalkForwardConfig`].
///
/// # Overfitting Warning
///
/// Results are **in-sample only**. Follow up with [`WalkForwardConfig`] or a
/// held-out test window to obtain an unbiased out-of-sample estimate.
///
/// [`WalkForwardConfig`]: super::super::walk_forward::WalkForwardConfig
#[derive(Debug, Clone, Default)]
pub struct BayesianSearch {
    // Named parameter ranges, in insertion order; one search dimension each.
    params: Vec<(String, ParamRange)>,
    // Target metric; `None` falls back to `OptimizeMetric::SharpeRatio` in `run`.
    metric: Option<OptimizeMetric>,
    // Evaluation budget; `None` falls back to `DEFAULT_MAX_EVALUATIONS`.
    max_evaluations: Option<usize>,
    // LHS sample count; `None` falls back to `DEFAULT_INITIAL_POINTS`,
    // then clamped to `[2, max_evaluations]`.
    initial_points: Option<usize>,
    // UCB β; `None` falls back to `DEFAULT_UCB_BETA`.
    ucb_beta: Option<f64>,
    // PRNG seed; `None` falls back to `DEFAULT_SEED`.
    seed: Option<u64>,
}
105
impl BayesianSearch {
    /// Create a new Bayesian search with no parameters defined yet.
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a named parameter range to search over.
    ///
    /// Use [`ParamRange::int_bounds`] / [`ParamRange::float_bounds`] for
    /// continuous ranges (recommended) or any [`ParamRange`] variant.
    pub fn param(mut self, name: impl Into<String>, range: ParamRange) -> Self {
        self.params.push((name.into(), range));
        self
    }

    /// Set the metric to optimise for (defaults to [`OptimizeMetric::SharpeRatio`]).
    pub fn optimize_for(mut self, metric: OptimizeMetric) -> Self {
        self.metric = Some(metric);
        self
    }

    /// Maximum total strategy evaluations, including the initial LHS phase (default: 100).
    pub fn max_evaluations(mut self, n: usize) -> Self {
        self.max_evaluations = Some(n);
        self
    }

    /// Number of initial random (LHS) samples before the surrogate is fitted (default: 10).
    ///
    /// Clamped to `[2, max_evaluations]`. More initial points improve surrogate
    /// quality at the cost of fewer sequential refinement steps.
    pub fn initial_points(mut self, n: usize) -> Self {
        self.initial_points = Some(n);
        self
    }

    /// UCB exploration–exploitation coefficient β (default: 2.0).
    ///
    /// Higher values drive broader exploration of uncertain regions;
    /// lower values concentrate search near already-good parameter sets.
    pub fn ucb_beta(mut self, beta: f64) -> Self {
        self.ucb_beta = Some(beta);
        self
    }

    /// PRNG seed for reproducible runs (default: 42).
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Run the Bayesian search.
    ///
    /// `symbol` is used only for labelling in the returned results.
    ///
    /// `factory` receives the current parameter map and returns a strategy
    /// instance. Parameter sets incompatible with the candle series (warmup
    /// too long) are silently skipped.
    ///
    /// Returns an error only when no parameters are defined or every evaluation
    /// was skipped due to insufficient data.
    pub fn run<S, F>(
        &self,
        symbol: &str,
        candles: &[Candle],
        config: &BacktestConfig,
        factory: F,
    ) -> Result<OptimizationReport>
    where
        S: Strategy,
        F: Fn(&HashMap<String, ParamValue>) -> S,
    {
        // At least one dimension is required; otherwise there is nothing to search.
        if self.params.is_empty() {
            return Err(BacktestError::invalid_param(
                "params",
                "BayesianSearch requires at least one parameter range",
            ));
        }

        // Resolve builder options against the documented defaults.
        let d = self.params.len();
        let metric = self.metric.unwrap_or(OptimizeMetric::SharpeRatio);
        let max_eval = self.max_evaluations.unwrap_or(DEFAULT_MAX_EVALUATIONS);
        let n_init = self
            .initial_points
            .unwrap_or(DEFAULT_INITIAL_POINTS)
            .max(2)
            .min(max_eval);
        let beta = self.ucb_beta.unwrap_or(DEFAULT_UCB_BETA);
        let seed = self.seed.unwrap_or(DEFAULT_SEED);

        let mut rng = Xorshift64::new(seed);
        // (unit-hypercube coords, metric score) for all successful evaluations.
        let mut observations: Vec<(Vec<f64>, f64)> = Vec::with_capacity(max_eval);
        let mut all_results: Vec<OptimizationResult> = Vec::with_capacity(max_eval);
        // Running best score after each successful evaluation (non-decreasing).
        let mut convergence_curve: Vec<f64> = Vec::with_capacity(max_eval);
        // Counts attempted evaluations, including skipped (insufficient-data) ones.
        let mut n_evaluations: usize = 0;
        let mut best_score: Option<f64> = None;

        // ── Phase 1: Latin Hypercube initial sampling ──────────────────────────

        for norm_point in latin_hypercube_sample(n_init, d, &mut rng) {
            n_evaluations += 1;
            if let Some(opt_result) = run_one(
                symbol,
                candles,
                config,
                &metric,
                &factory,
                &norm_point,
                &self.params,
            ) {
                let score = metric.score(&opt_result.result);
                // Non-finite scores are excluded from the surrogate's training set.
                if score.is_finite() {
                    update_best(&mut best_score, score);
                    observations.push((norm_point, score));
                }
                if let Some(b) = best_score {
                    convergence_curve.push(b);
                }
                all_results.push(opt_result);
            }
        }

        // ── Phase 2: Sequential surrogate-guided search ────────────────────────

        for _ in 0..max_eval.saturating_sub(n_init) {
            let norm_point = if observations.len() < 2 {
                // Too few observations for a reliable surrogate — fall back to random.
                (0..d).map(|_| rng.next_f64_positive()).collect()
            } else {
                // Refit the surrogate on all observations, then pick the random
                // candidate with the highest UCB acquisition value.
                let surrogate = Surrogate::fit(&observations, beta);
                (0..N_CANDIDATES)
                    .map(|_| {
                        (0..d)
                            .map(|_| rng.next_f64_positive())
                            .collect::<Vec<f64>>()
                    })
                    .max_by(|a, b| {
                        surrogate
                            .acquisition(a)
                            .partial_cmp(&surrogate.acquisition(b))
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    // SAFETY: N_CANDIDATES > 0.
                    .unwrap()
            };

            n_evaluations += 1;
            if let Some(opt_result) = run_one(
                symbol,
                candles,
                config,
                &metric,
                &factory,
                &norm_point,
                &self.params,
            ) {
                let score = metric.score(&opt_result.result);
                if score.is_finite() {
                    update_best(&mut best_score, score);
                    observations.push((norm_point, score));
                }
                if let Some(b) = best_score {
                    convergence_curve.push(b);
                }
                all_results.push(opt_result);
            }
        }

        // ── Finalise ───────────────────────────────────────────────────────────

        if all_results.is_empty() {
            return Err(BacktestError::invalid_param(
                "candles",
                "no parameter set had enough data to run a backtest",
            ));
        }

        sort_results_best_first(&mut all_results, metric);

        // Best-first ordering puts NaN last; a NaN at index 0 means all were NaN.
        if metric.score(&all_results[0].result).is_nan() {
            return Err(BacktestError::invalid_param(
                "metric",
                "all parameter sets produced NaN for the target metric",
            ));
        }

        let strategy_name = all_results[0].result.strategy_name.clone();
        let best = all_results[0].clone();
        let total_combinations = all_results.len();

        Ok(OptimizationReport {
            strategy_name,
            total_combinations,
            results: all_results,
            best,
            // NOTE(review): hard-coded 0 even though insufficient-data candidates
            // are skipped above; `n_evaluations - results.len()` would recover the
            // skip count — confirm against GridSearch's semantics for this field.
            skipped_errors: 0,
            convergence_curve,
            n_evaluations,
        })
    }
}
309
310// ── Internal helpers ──────────────────────────────────────────────────────────
311
/// Fold `score` into `best`, keeping the running maximum.
///
/// An empty `best` always accepts the first score; afterwards only strictly
/// greater scores replace it (ties keep the existing value).
#[inline]
fn update_best(best: &mut Option<f64>, score: f64) {
    let improved = best.map_or(true, |current| score > current);
    if improved {
        *best = Some(score);
    }
}
320
/// Run one backtest for a unit-hypercube point; returns `None` for
/// `InsufficientData` errors (silently skipped).
///
/// Other engine errors are also skipped, but logged at `warn` level so they
/// stay visible. `_metric` is currently unused; it is kept to match the
/// caller's argument list — TODO confirm whether it can be dropped.
fn run_one<S, F>(
    symbol: &str,
    candles: &[Candle],
    config: &BacktestConfig,
    _metric: &OptimizeMetric,
    factory: &F,
    norm_point: &[f64],
    param_specs: &[(String, ParamRange)],
) -> Option<OptimizationResult>
where
    S: Strategy,
    F: Fn(&HashMap<String, ParamValue>) -> S,
{
    // Map unit-hypercube coordinates to concrete, named parameter values.
    let params = denormalize(norm_point, param_specs);
    let strategy = factory(&params);
    match BacktestEngine::new(config.clone()).run(symbol, candles, strategy) {
        Ok(result) => Some(OptimizationResult { params, result }),
        // This parameter set needs more warmup data than `candles` provides.
        Err(BacktestError::InsufficientData { .. }) => None,
        Err(e) => {
            tracing::warn!(
                params = ?params,
                error = %e,
                "BayesianSearch: skipping candidate due to unexpected error"
            );
            None
        }
    }
}
351
352/// Convert unit-hypercube coordinates `t[i] ∈ (0, 1]` into named [`ParamValue`]s.
353fn denormalize(
354    norm_point: &[f64],
355    param_specs: &[(String, ParamRange)],
356) -> HashMap<String, ParamValue> {
357    norm_point
358        .iter()
359        .zip(param_specs.iter())
360        .map(|(&t, (name, range))| (name.clone(), range.sample_at(t)))
361        .collect()
362}
363
364// ── Latin Hypercube Sampling ──────────────────────────────────────────────────
365
366/// Generate `n` stratified random samples in the `d`-dimensional unit hypercube.
367///
368/// Each dimension is divided into `n` equal strata; exactly one sample is drawn
369/// from each stratum per dimension. Stratum assignments are independently
370/// shuffled across dimensions, giving good marginal coverage with low
371/// inter-dimension correlation — significantly better than IID uniform sampling.
372fn latin_hypercube_sample(n: usize, d: usize, rng: &mut Xorshift64) -> Vec<Vec<f64>> {
373    if n == 0 {
374        return vec![];
375    }
376
377    let mut samples = vec![vec![0.0_f64; d]; n];
378
379    #[allow(clippy::needless_range_loop)]
380    for dim in 0..d {
381        // One value per stratum [i/n, (i+1)/n).
382        let mut stratum_values: Vec<f64> = (0..n)
383            .map(|i| {
384                let lo = i as f64 / n as f64;
385                let hi = (i + 1) as f64 / n as f64;
386                lo + rng.next_f64_positive() * (hi - lo)
387            })
388            .collect();
389
390        // Fisher-Yates shuffle of stratum assignments for this dimension.
391        for i in (1..n).rev() {
392            let j = rng.next_usize(i + 1);
393            stratum_values.swap(i, j);
394        }
395
396        for i in 0..n {
397            samples[i][dim] = stratum_values[i];
398        }
399    }
400
401    samples
402}
403
// ── Surrogate model ───────────────────────────────────────────────────────────

/// Nadaraya-Watson kernel regression surrogate with UCB acquisition.
///
/// Given observed `(x, y)` pairs (unit-hypercube coords and metric scores),
/// models the objective surface as a Gaussian-kernel-weighted average.
///
/// **Why kernel regression?** It is dependency-free, numerically stable,
/// non-parametric, and the mean/variance formulas are five lines of arithmetic.
/// The trade-off vs. a Gaussian Process is that it does not provide a
/// calibrated predictive distribution, but UCB acquisition works well in
/// practice for backtesting parameter search.
struct Surrogate<'a> {
    /// Observed `(unit-hypercube coords, metric score)` training pairs.
    observations: &'a [(Vec<f64>, f64)],
    /// UCB exploration–exploitation coefficient β.
    beta: f64,
    /// Pre-computed `2h²` denominator for the RBF kernel exponent.
    bandwidth_sq: f64,
}

impl<'a> Surrogate<'a> {
    /// Fit the surrogate to a set of `(unit-hypercube coords, score)` pairs.
    ///
    /// Bandwidth: `h = n^(-1/(d+4))` (Silverman-inspired), floored at 0.1 to
    /// avoid near-degenerate kernels with very few data points.
    fn fit(observations: &'a [(Vec<f64>, f64)], beta: f64) -> Self {
        let n = observations.len() as f64;
        let d = observations.first().map_or(1, |(x, _)| x.len()) as f64;
        let h = n.powf(-1.0 / (d + 4.0)).max(0.1);
        Self {
            observations,
            beta,
            bandwidth_sq: 2.0 * h * h,
        }
    }

    /// UCB acquisition: `μ(x) + β·σ(x)`.
    fn acquisition(&self, x: &[f64]) -> f64 {
        let (mean, std) = self.predict(x);
        mean + self.beta * std
    }

    /// Nadaraya-Watson mean and weighted standard deviation at `x`.
    ///
    /// Returns `(0.0, 1.0)` — maximum uncertainty — when all observations are
    /// too distant to contribute meaningful kernel weight.
    fn predict(&self, x: &[f64]) -> (f64, f64) {
        // Cache the kernel weights so each (relatively expensive) RBF
        // evaluation happens exactly once per observation, instead of twice
        // (once for the mean pass and again for the variance pass).
        let weights: Vec<f64> = self
            .observations
            .iter()
            .map(|(xi, _)| self.rbf(x, xi))
            .collect();

        let w_sum: f64 = weights.iter().sum();
        let wy_sum: f64 = weights
            .iter()
            .zip(self.observations)
            .map(|(w, (_, yi))| w * yi)
            .sum();

        // All kernel weights underflowed — nothing nearby to predict from.
        if w_sum < f64::EPSILON {
            return (0.0, 1.0);
        }

        let mean = wy_sum / w_sum;

        // Kernel-weighted variance of the scores around the local mean.
        let wvar: f64 = weights
            .iter()
            .zip(self.observations)
            .map(|(w, (_, yi))| {
                let diff = yi - mean;
                w * diff * diff
            })
            .sum();
        let std = (wvar / w_sum).max(0.0).sqrt();

        (mean, std)
    }

    /// Gaussian (RBF) kernel: `exp(-‖x − xᵢ‖² / (2h²))`.
    #[inline]
    fn rbf(&self, x: &[f64], xi: &[f64]) -> f64 {
        let dist_sq: f64 = x.iter().zip(xi.iter()).map(|(a, b)| (a - b).powi(2)).sum();
        (-dist_sq / self.bandwidth_sq).exp()
    }
}
482
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::{BacktestConfig, SmaCrossover};
    use crate::models::chart::Candle;

    /// Build a synthetic candle series with ±1% high/low around each close.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1_000,
                adj_close: Some(p),
            })
            .collect()
    }

    /// Linear uptrend: 100.0, 100.5, 101.0, … (favourable for SMA crossover).
    fn trending_prices(n: usize) -> Vec<f64> {
        (0..n).map(|i| 100.0 + i as f64 * 0.5).collect()
    }

    // ── LHS ───────────────────────────────────────────────────────────────────

    #[test]
    fn test_lhs_shape() {
        let mut rng = Xorshift64::new(1);
        let samples = latin_hypercube_sample(8, 3, &mut rng);
        assert_eq!(samples.len(), 8);
        assert!(samples.iter().all(|p| p.len() == 3));
    }

    #[test]
    fn test_lhs_stratification() {
        let n = 10;
        let mut rng = Xorshift64::new(99);
        let samples = latin_hypercube_sample(n, 2, &mut rng);

        // LHS invariant: exactly one sample per stratum in every dimension.
        for dim in 0..2 {
            let mut counts = vec![0usize; n];
            for point in &samples {
                let stratum = (point[dim] * n as f64).floor() as usize;
                counts[stratum.min(n - 1)] += 1;
            }
            assert!(
                counts.iter().all(|&c| c == 1),
                "dim {dim}: expected one sample per stratum, got {counts:?}"
            );
        }
    }

    #[test]
    fn test_lhs_values_in_unit_cube() {
        let mut rng = Xorshift64::new(7);
        for point in latin_hypercube_sample(20, 4, &mut rng) {
            for v in point {
                assert!(v > 0.0 && v <= 1.0, "value {v} outside (0, 1]");
            }
        }
    }

    // ── Surrogate ─────────────────────────────────────────────────────────────

    #[test]
    fn test_surrogate_predicts_near_observation() {
        let obs = vec![(vec![0.5_f64], 1.0_f64)];
        let s = Surrogate::fit(&obs, 2.0);
        let (mean, _) = s.predict(&[0.5]);
        assert!((mean - 1.0).abs() < 1e-6);
    }

    /// A point so far from all observations that `exp(-dist²/2h²) < ε` triggers
    /// the maximum-uncertainty fallback path, returning `(0.0, 1.0)`.
    #[test]
    fn test_surrogate_max_uncertainty_fallback_for_very_distant_point() {
        // At x=100 the kernel weight is exp(-10000/bandwidth_sq) which underflows
        // to exactly 0.0 in f64, so w_sum < EPSILON and the fallback is taken.
        let obs = vec![(vec![0.0_f64], 0.5_f64), (vec![0.1], 0.6)];
        let s = Surrogate::fit(&obs, 2.0);
        let (mean, std) = s.predict(&[100.0]);
        assert!(
            (mean - 0.0).abs() < 1e-6,
            "expected fallback mean=0.0, got {mean}"
        );
        assert!(
            (std - 1.0).abs() < 1e-6,
            "expected fallback std=1.0, got {std}"
        );
    }

    /// When two nearby observations have very different scores, the surrogate
    /// should report non-trivial variance at the midpoint.
    #[test]
    fn test_surrogate_std_nonzero_with_disagreeing_observations() {
        let obs = vec![(vec![0.0_f64], 0.1_f64), (vec![0.05], 0.9)];
        let s = Surrogate::fit(&obs, 2.0);
        let (_, std) = s.predict(&[0.025]); // midpoint — equal weight to both
        assert!(
            std > 0.1,
            "expected non-trivial std for disagreeing observations, got {std}"
        );
    }

    #[test]
    fn test_acquisition_favours_uncertain_regions_with_high_beta() {
        let obs = vec![(vec![0.0_f64], 0.5_f64), (vec![0.1], 0.6)];
        let s = Surrogate::fit(&obs, 10.0); // high β → exploration-heavy
        assert!(
            s.acquisition(&[1.0]) > s.acquisition(&[0.05]),
            "far point should have higher UCB with β=10"
        );
    }

    // ── BayesianSearch integration ────────────────────────────────────────────

    #[test]
    fn test_bayesian_search_runs() {
        let candles = make_candles(&trending_prices(200));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 10))
            .param("slow", ParamRange::int_bounds(10, 30))
            .optimize_for(OptimizeMetric::TotalReturn)
            .max_evaluations(20)
            .seed(1)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        assert!(!report.results.is_empty());
        assert_eq!(report.strategy_name, "SMA Crossover");
        assert!(report.n_evaluations > 0);
        assert!(!report.convergence_curve.is_empty());
    }

    #[test]
    fn test_convergence_curve_is_nondecreasing() {
        let candles = make_candles(&trending_prices(200));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 15))
            .param("slow", ParamRange::int_bounds(15, 40))
            .max_evaluations(30)
            .seed(2)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        // The curve tracks the running best score, so it must never decrease.
        for window in report.convergence_curve.windows(2) {
            assert!(
                window[1] >= window[0] - 1e-12,
                "convergence curve not non-decreasing: {window:?}"
            );
        }
    }

    #[test]
    fn test_results_sorted_best_first() {
        let candles = make_candles(&trending_prices(150));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 10))
            .param("slow", ParamRange::int_bounds(10, 25))
            .optimize_for(OptimizeMetric::TotalReturn)
            .max_evaluations(15)
            .seed(3)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        if report.results.len() > 1 {
            let first = OptimizeMetric::TotalReturn.score(&report.results[0].result);
            let second = OptimizeMetric::TotalReturn.score(&report.results[1].result);
            assert!(first >= second - 1e-12);
        }
    }

    #[test]
    fn test_best_matches_results_first() {
        let candles = make_candles(&trending_prices(150));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 10))
            .param("slow", ParamRange::int_bounds(10, 25))
            .max_evaluations(15)
            .seed(4)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        // `best` must be a clone of the top-ranked entry in `results`.
        let best = OptimizeMetric::SharpeRatio.score(&report.best.result);
        let first = OptimizeMetric::SharpeRatio.score(&report.results[0].result);
        assert!((best - first).abs() < 1e-12);
    }

    #[test]
    fn test_no_params_returns_error() {
        let candles = make_candles(&trending_prices(100));
        let config = BacktestConfig::default();
        assert!(
            BayesianSearch::new()
                .run("TEST", &candles, &config, |_| SmaCrossover::new(5, 20))
                .is_err()
        );
    }

    #[test]
    fn test_seeded_runs_are_reproducible() {
        let candles = make_candles(&trending_prices(200));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let search = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 12))
            .param("slow", ParamRange::int_bounds(12, 30))
            .max_evaluations(15)
            .seed(77);

        let factory = |p: &HashMap<String, ParamValue>| {
            SmaCrossover::new(p["fast"].as_int() as usize, p["slow"].as_int() as usize)
        };

        // Same seed → identical evaluation sequence and identical best result.
        let r1 = search
            .clone()
            .run("TEST", &candles, &config, factory)
            .unwrap();
        let r2 = search.run("TEST", &candles, &config, factory).unwrap();

        assert_eq!(r1.n_evaluations, r2.n_evaluations);
        assert_eq!(r1.convergence_curve, r2.convergence_curve);
        assert_eq!(
            r1.best.result.metrics.total_return_pct,
            r2.best.result.metrics.total_return_pct
        );
    }
}