1use std::collections::HashMap;
55
56use crate::models::chart::Candle;
57
58use super::super::config::BacktestConfig;
59use super::super::engine::BacktestEngine;
60use super::super::error::{BacktestError, Result};
61use super::super::monte_carlo::Xorshift64;
62use super::super::strategy::Strategy;
63use super::{
64 OptimizationReport, OptimizationResult, OptimizeMetric, ParamRange, ParamValue,
65 sort_results_best_first,
66};
67
/// Total evaluation budget (warm-up points + Bayesian iterations) used when
/// the caller does not override it via `max_evaluations`.
const DEFAULT_MAX_EVALUATIONS: usize = 100;
/// Number of Latin-hypercube warm-up points drawn before the surrogate model
/// starts guiding the search.
const DEFAULT_INITIAL_POINTS: usize = 10;
/// Exploration weight β in the UCB acquisition `mean + β·std`.
const DEFAULT_UCB_BETA: f64 = 2.0;
/// RNG seed applied when the caller does not supply one, keeping unseeded
/// runs deterministic.
const DEFAULT_SEED: u64 = 42;
/// Random candidates scored by the surrogate per iteration when maximizing
/// the acquisition function.
const N_CANDIDATES: usize = 1_000;
78
/// Builder-style Bayesian optimizer over named strategy-parameter ranges.
///
/// All options are `Option`s; anything left unset falls back to the
/// module-level `DEFAULT_*` constants when `run` executes.
#[derive(Debug, Clone, Default)]
pub struct BayesianSearch {
    // Parameter names with their search ranges, in insertion order.
    params: Vec<(String, ParamRange)>,
    // Metric to maximize; `run` defaults this to Sharpe ratio.
    metric: Option<OptimizeMetric>,
    // Total backtest-evaluation budget.
    max_evaluations: Option<usize>,
    // Size of the initial Latin-hypercube design.
    initial_points: Option<usize>,
    // UCB exploration weight β.
    ucb_beta: Option<f64>,
    // Seed for the deterministic internal RNG.
    seed: Option<u64>,
}
105
impl BayesianSearch {
    /// Creates an empty search; configure it with the builder methods below.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a named parameter range to optimize over.
    pub fn param(mut self, name: impl Into<String>, range: ParamRange) -> Self {
        self.params.push((name.into(), range));
        self
    }

    /// Selects the metric to maximize (default: Sharpe ratio).
    pub fn optimize_for(mut self, metric: OptimizeMetric) -> Self {
        self.metric = Some(metric);
        self
    }

    /// Caps the total number of backtest evaluations (default: 100).
    pub fn max_evaluations(mut self, n: usize) -> Self {
        self.max_evaluations = Some(n);
        self
    }

    /// Sets how many Latin-hypercube points seed the surrogate before the
    /// Bayesian loop starts (default: 10; `run` clamps this to at least 2 and
    /// at most `max_evaluations`).
    pub fn initial_points(mut self, n: usize) -> Self {
        self.initial_points = Some(n);
        self
    }

    /// Sets the UCB exploration weight β (default: 2.0). Larger values favor
    /// exploring uncertain regions over exploiting the current best.
    pub fn ucb_beta(mut self, beta: f64) -> Self {
        self.ucb_beta = Some(beta);
        self
    }

    /// Seeds the internal RNG so runs are reproducible (default: 42).
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Runs the optimization: a Latin-hypercube warm-up phase followed by
    /// UCB-guided iterations over a kernel-regression surrogate.
    ///
    /// `factory` builds a fresh strategy instance for each candidate
    /// parameter map.
    ///
    /// # Errors
    ///
    /// Returns an error when no parameter ranges were registered, when no
    /// candidate produced a runnable backtest, or when every result scored
    /// NaN on the target metric.
    pub fn run<S, F>(
        &self,
        symbol: &str,
        candles: &[Candle],
        config: &BacktestConfig,
        factory: F,
    ) -> Result<OptimizationReport>
    where
        S: Strategy,
        F: Fn(&HashMap<String, ParamValue>) -> S,
    {
        if self.params.is_empty() {
            return Err(BacktestError::invalid_param(
                "params",
                "BayesianSearch requires at least one parameter range",
            ));
        }

        // Resolve settings, falling back to module defaults.
        let d = self.params.len();
        let metric = self.metric.unwrap_or(OptimizeMetric::SharpeRatio);
        let max_eval = self.max_evaluations.unwrap_or(DEFAULT_MAX_EVALUATIONS);
        // At least 2 warm-up observations are needed to fit the surrogate,
        // and the warm-up can never exceed the total budget.
        let n_init = self
            .initial_points
            .unwrap_or(DEFAULT_INITIAL_POINTS)
            .max(2)
            .min(max_eval);
        let beta = self.ucb_beta.unwrap_or(DEFAULT_UCB_BETA);
        let seed = self.seed.unwrap_or(DEFAULT_SEED);

        let mut rng = Xorshift64::new(seed);
        // (normalized point, score) pairs that feed the surrogate.
        let mut observations: Vec<(Vec<f64>, f64)> = Vec::with_capacity(max_eval);
        let mut all_results: Vec<OptimizationResult> = Vec::with_capacity(max_eval);
        // Best-so-far score recorded after each successful evaluation.
        let mut convergence_curve: Vec<f64> = Vec::with_capacity(max_eval);
        let mut n_evaluations: usize = 0;
        let mut best_score: Option<f64> = None;

        // Phase 1: space-filling warm-up with Latin hypercube samples.
        for norm_point in latin_hypercube_sample(n_init, d, &mut rng) {
            n_evaluations += 1;
            if let Some(opt_result) = run_one(
                symbol,
                candles,
                config,
                &metric,
                &factory,
                &norm_point,
                &self.params,
            ) {
                let score = metric.score(&opt_result.result);
                // Only finite scores are usable as surrogate observations;
                // the raw result is still reported either way.
                if score.is_finite() {
                    update_best(&mut best_score, score);
                    observations.push((norm_point, score));
                }
                if let Some(b) = best_score {
                    convergence_curve.push(b);
                }
                all_results.push(opt_result);
            }
        }

        // Phase 2: choose each next point by maximizing the UCB acquisition
        // over N_CANDIDATES random candidates scored by the surrogate.
        for _ in 0..max_eval.saturating_sub(n_init) {
            let norm_point = if observations.len() < 2 {
                // Not enough data to fit the surrogate: sample uniformly.
                (0..d).map(|_| rng.next_f64_positive()).collect()
            } else {
                let surrogate = Surrogate::fit(&observations, beta);
                let mut candidate = vec![0.0_f64; d];
                let mut best_ucb = f64::NEG_INFINITY;
                let mut best = vec![0.0_f64; d];
                for _ in 0..N_CANDIDATES {
                    for xi in candidate.iter_mut() {
                        *xi = rng.next_f64_positive();
                    }
                    let ucb = surrogate.acquisition(&candidate);
                    if ucb > best_ucb {
                        best_ucb = ucb;
                        best.copy_from_slice(&candidate);
                    }
                }
                best
            };

            n_evaluations += 1;
            if let Some(opt_result) = run_one(
                symbol,
                candles,
                config,
                &metric,
                &factory,
                &norm_point,
                &self.params,
            ) {
                let score = metric.score(&opt_result.result);
                if score.is_finite() {
                    update_best(&mut best_score, score);
                    observations.push((norm_point, score));
                }
                if let Some(b) = best_score {
                    convergence_curve.push(b);
                }
                all_results.push(opt_result);
            }
        }

        if all_results.is_empty() {
            return Err(BacktestError::invalid_param(
                "candles",
                "no parameter set had enough data to run a backtest",
            ));
        }

        sort_results_best_first(&mut all_results, metric);

        // After sorting, index 0 holds the best result; a NaN there means
        // every candidate failed to produce a usable metric value.
        if metric.score(&all_results[0].result).is_nan() {
            return Err(BacktestError::invalid_param(
                "metric",
                "all parameter sets produced NaN for the target metric",
            ));
        }

        let strategy_name = all_results[0].result.strategy_name.clone();
        let best = all_results[0].clone();
        let total_combinations = all_results.len();

        Ok(OptimizationReport {
            strategy_name,
            total_combinations,
            results: all_results,
            best,
            skipped_errors: 0,
            convergence_curve,
            n_evaluations,
        })
    }
}
311
/// Folds `score` into `best`, keeping the running maximum.
///
/// A `None` best is always replaced; otherwise the stored value is
/// overwritten only when `score` strictly exceeds it (so a NaN score never
/// replaces an existing best).
#[inline]
fn update_best(best: &mut Option<f64>, score: f64) {
    let improved = best.map_or(true, |current| score > current);
    if improved {
        *best = Some(score);
    }
}
322
323fn run_one<S, F>(
326 symbol: &str,
327 candles: &[Candle],
328 config: &BacktestConfig,
329 _metric: &OptimizeMetric,
330 factory: &F,
331 norm_point: &[f64],
332 param_specs: &[(String, ParamRange)],
333) -> Option<OptimizationResult>
334where
335 S: Strategy,
336 F: Fn(&HashMap<String, ParamValue>) -> S,
337{
338 let params = denormalize(norm_point, param_specs);
339 let strategy = factory(¶ms);
340 match BacktestEngine::new(config.clone()).run(symbol, candles, strategy) {
341 Ok(result) => Some(OptimizationResult { params, result }),
342 Err(BacktestError::InsufficientData { .. }) => None,
343 Err(e) => {
344 tracing::warn!(
345 params = ?params,
346 error = %e,
347 "BayesianSearch: skipping candidate due to unexpected error"
348 );
349 None
350 }
351 }
352}
353
354fn denormalize(
356 norm_point: &[f64],
357 param_specs: &[(String, ParamRange)],
358) -> HashMap<String, ParamValue> {
359 norm_point
360 .iter()
361 .zip(param_specs.iter())
362 .map(|(&t, (name, range))| (name.clone(), range.sample_at(t)))
363 .collect()
364}
365
366fn latin_hypercube_sample(n: usize, d: usize, rng: &mut Xorshift64) -> Vec<Vec<f64>> {
375 if n == 0 {
376 return vec![];
377 }
378
379 let mut samples = vec![vec![0.0_f64; d]; n];
380
381 #[allow(clippy::needless_range_loop)]
382 for dim in 0..d {
383 let mut stratum_values: Vec<f64> = (0..n)
385 .map(|i| {
386 let lo = i as f64 / n as f64;
387 let hi = (i + 1) as f64 / n as f64;
388 lo + rng.next_f64_positive() * (hi - lo)
389 })
390 .collect();
391
392 for i in (1..n).rev() {
394 let j = rng.next_usize(i + 1);
395 stratum_values.swap(i, j);
396 }
397
398 for i in 0..n {
399 samples[i][dim] = stratum_values[i];
400 }
401 }
402
403 samples
404}
405
/// Cheap kernel-regression surrogate used to rank candidate points.
///
/// A Gaussian (RBF) kernel smooths the observed scores: the kernel-weighted
/// mean is the predicted value and the kernel-weighted standard deviation is
/// the uncertainty estimate feeding the UCB acquisition.
struct Surrogate<'a> {
    // Normalized points paired with their observed scores.
    observations: &'a [(Vec<f64>, f64)],
    // Exploration weight β in `mean + β·std`.
    beta: f64,
    // Precomputed `2·h²` denominator of the RBF kernel.
    bandwidth_sq: f64,
}

impl<'a> Surrogate<'a> {
    /// Builds a surrogate over `observations`, choosing the kernel bandwidth
    /// as `n^(-1/(d+4))` floored at 0.1.
    fn fit(observations: &'a [(Vec<f64>, f64)], beta: f64) -> Self {
        let count = observations.len() as f64;
        let dims = observations.first().map_or(1, |(point, _)| point.len()) as f64;
        let bandwidth = count.powf(-1.0 / (dims + 4.0)).max(0.1);
        Self {
            observations,
            beta,
            bandwidth_sq: 2.0 * bandwidth * bandwidth,
        }
    }

    /// Upper-confidence-bound score for candidate `x`: predicted mean plus
    /// `beta` times the predicted standard deviation.
    fn acquisition(&self, x: &[f64]) -> f64 {
        let (mean, std) = self.predict(x);
        mean + self.beta * std
    }

    /// Kernel-weighted mean and standard deviation at `x`, accumulated with
    /// West's incremental algorithm for weighted moments.
    ///
    /// When every observation is effectively out of kernel range, returns the
    /// `(0.0, 1.0)` prior: unknown region, maximal uncertainty.
    fn predict(&self, x: &[f64]) -> (f64, f64) {
        let mut total_weight = 0.0_f64;
        let mut mean = 0.0_f64;
        let mut m2 = 0.0_f64;

        for (point, value) in self.observations {
            let weight = self.rbf(x, point);
            if weight < f64::EPSILON {
                // Numerically irrelevant contribution; skipping also avoids
                // dividing by a near-zero cumulative weight.
                continue;
            }
            let next_weight = total_weight + weight;
            let residual = value - mean;
            mean += (weight / next_weight) * residual;
            m2 += weight * residual * (value - mean);
            total_weight = next_weight;
        }

        if total_weight < f64::EPSILON {
            return (0.0, 1.0);
        }

        let std = (m2 / total_weight).max(0.0).sqrt();
        (mean, std)
    }

    /// Gaussian kernel weight between query `x` and observation `xi`.
    #[inline]
    fn rbf(&self, x: &[f64], xi: &[f64]) -> f64 {
        let dist_sq: f64 = x
            .iter()
            .zip(xi)
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        (-dist_sq / self.bandwidth_sq).exp()
    }
}
487
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::{BacktestConfig, SmaCrossover};
    use crate::models::chart::Candle;

    // One candle per price: OHLC within ±1% of the close, constant volume.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1_000,
                adj_close: Some(p),
            })
            .collect()
    }

    // Linearly rising price series: 100.0, 100.5, 101.0, ...
    fn trending_prices(n: usize) -> Vec<f64> {
        (0..n).map(|i| 100.0 + i as f64 * 0.5).collect()
    }

    // LHS returns n points, each of dimension d.
    #[test]
    fn test_lhs_shape() {
        let mut rng = Xorshift64::new(1);
        let samples = latin_hypercube_sample(8, 3, &mut rng);
        assert_eq!(samples.len(), 8);
        assert!(samples.iter().all(|p| p.len() == 3));
    }

    // Each dimension places exactly one sample in each of the n strata.
    #[test]
    fn test_lhs_stratification() {
        let n = 10;
        let mut rng = Xorshift64::new(99);
        let samples = latin_hypercube_sample(n, 2, &mut rng);

        for dim in 0..2 {
            let mut counts = vec![0usize; n];
            for point in &samples {
                let stratum = (point[dim] * n as f64).floor() as usize;
                counts[stratum.min(n - 1)] += 1;
            }
            assert!(
                counts.iter().all(|&c| c == 1),
                "dim {dim}: expected one sample per stratum, got {counts:?}"
            );
        }
    }

    // All coordinates fall in (0, 1] (next_f64_positive excludes 0).
    #[test]
    fn test_lhs_values_in_unit_cube() {
        let mut rng = Xorshift64::new(7);
        for point in latin_hypercube_sample(20, 4, &mut rng) {
            for v in point {
                assert!(v > 0.0 && v <= 1.0, "value {v} outside (0, 1]");
            }
        }
    }

    // Predicting at an observed point recovers its score.
    #[test]
    fn test_surrogate_predicts_near_observation() {
        let obs = vec![(vec![0.5_f64], 1.0_f64)];
        let s = Surrogate::fit(&obs, 2.0);
        let (mean, _) = s.predict(&[0.5]);
        assert!((mean - 1.0).abs() < 1e-6);
    }

    // Far from all observations the surrogate falls back to the
    // (mean=0.0, std=1.0) "unknown region" prior.
    #[test]
    fn test_surrogate_max_uncertainty_fallback_for_very_distant_point() {
        let obs = vec![(vec![0.0_f64], 0.5_f64), (vec![0.1], 0.6)];
        let s = Surrogate::fit(&obs, 2.0);
        let (mean, std) = s.predict(&[100.0]);
        assert!(
            (mean - 0.0).abs() < 1e-6,
            "expected fallback mean=0.0, got {mean}"
        );
        assert!(
            (std - 1.0).abs() < 1e-6,
            "expected fallback std=1.0, got {std}"
        );
    }

    // Nearby observations with very different scores must yield nonzero std.
    #[test]
    fn test_surrogate_std_nonzero_with_disagreeing_observations() {
        let obs = vec![(vec![0.0_f64], 0.1_f64), (vec![0.05], 0.9)];
        let s = Surrogate::fit(&obs, 2.0);
        let (_, std) = s.predict(&[0.025]);
        assert!(
            std > 0.1,
            "expected non-trivial std for disagreeing observations, got {std}"
        );
    }

    // With a large β the UCB acquisition prefers uncertain (distant) points.
    #[test]
    fn test_acquisition_favours_uncertain_regions_with_high_beta() {
        let obs = vec![(vec![0.0_f64], 0.5_f64), (vec![0.1], 0.6)];
        let s = Surrogate::fit(&obs, 10.0);
        assert!(
            s.acquisition(&[1.0]) > s.acquisition(&[0.05]),
            "far point should have higher UCB with β=10"
        );
    }

    // End-to-end smoke test: the search produces a populated report.
    #[test]
    fn test_bayesian_search_runs() {
        let candles = make_candles(&trending_prices(200));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 10))
            .param("slow", ParamRange::int_bounds(10, 30))
            .optimize_for(OptimizeMetric::TotalReturn)
            .max_evaluations(20)
            .seed(1)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        assert!(!report.results.is_empty());
        assert_eq!(report.strategy_name, "SMA Crossover");
        assert!(report.n_evaluations > 0);
        assert!(!report.convergence_curve.is_empty());
    }

    // The best-so-far curve can only improve or stay flat over time.
    #[test]
    fn test_convergence_curve_is_nondecreasing() {
        let candles = make_candles(&trending_prices(200));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 15))
            .param("slow", ParamRange::int_bounds(15, 40))
            .max_evaluations(30)
            .seed(2)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        for window in report.convergence_curve.windows(2) {
            assert!(
                window[1] >= window[0] - 1e-12,
                "convergence curve not non-decreasing: {window:?}"
            );
        }
    }

    // Reported results come back sorted best-first on the target metric.
    #[test]
    fn test_results_sorted_best_first() {
        let candles = make_candles(&trending_prices(150));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 10))
            .param("slow", ParamRange::int_bounds(10, 25))
            .optimize_for(OptimizeMetric::TotalReturn)
            .max_evaluations(15)
            .seed(3)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        if report.results.len() > 1 {
            let first = OptimizeMetric::TotalReturn.score(&report.results[0].result);
            let second = OptimizeMetric::TotalReturn.score(&report.results[1].result);
            assert!(first >= second - 1e-12);
        }
    }

    // `report.best` is a copy of `report.results[0]`.
    #[test]
    fn test_best_matches_results_first() {
        let candles = make_candles(&trending_prices(150));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 10))
            .param("slow", ParamRange::int_bounds(10, 25))
            .max_evaluations(15)
            .seed(4)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        let best = OptimizeMetric::SharpeRatio.score(&report.best.result);
        let first = OptimizeMetric::SharpeRatio.score(&report.results[0].result);
        assert!((best - first).abs() < 1e-12);
    }

    // Running with no registered parameter ranges is an error.
    #[test]
    fn test_no_params_returns_error() {
        let candles = make_candles(&trending_prices(100));
        let config = BacktestConfig::default();
        assert!(
            BayesianSearch::new()
                .run("TEST", &candles, &config, |_| SmaCrossover::new(5, 20))
                .is_err()
        );
    }

    // Identical seeds must reproduce evaluations, curve, and best result.
    #[test]
    fn test_seeded_runs_are_reproducible() {
        let candles = make_candles(&trending_prices(200));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let search = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 12))
            .param("slow", ParamRange::int_bounds(12, 30))
            .max_evaluations(15)
            .seed(77);

        let factory = |p: &HashMap<String, ParamValue>| {
            SmaCrossover::new(p["fast"].as_int() as usize, p["slow"].as_int() as usize)
        };

        let r1 = search
            .clone()
            .run("TEST", &candles, &config, factory)
            .unwrap();
        let r2 = search.run("TEST", &candles, &config, factory).unwrap();

        assert_eq!(r1.n_evaluations, r2.n_evaluations);
        assert_eq!(r1.convergence_curve, r2.convergence_curve);
        assert_eq!(
            r1.best.result.metrics.total_return_pct,
            r2.best.result.metrics.total_return_pct
        );
    }
}