1use std::collections::HashMap;
55
56use crate::models::chart::Candle;
57
58use super::super::config::BacktestConfig;
59use super::super::engine::BacktestEngine;
60use super::super::error::{BacktestError, Result};
61use super::super::monte_carlo::Xorshift64;
62use super::super::strategy::Strategy;
63use super::{
64 OptimizationReport, OptimizationResult, OptimizeMetric, ParamRange, ParamValue,
65 sort_results_best_first,
66};
67
/// Total backtest budget when `max_evaluations` is not set by the caller.
const DEFAULT_MAX_EVALUATIONS: usize = 100;
/// Latin-hypercube warm-up evaluations before the surrogate takes over.
const DEFAULT_INITIAL_POINTS: usize = 10;
/// Exploration weight in the UCB acquisition (`mean + beta * std`).
const DEFAULT_UCB_BETA: f64 = 2.0;
/// RNG seed used when the caller does not provide one (reproducible runs).
const DEFAULT_SEED: u64 = 42;
/// Random candidate points scored by the surrogate per optimization step.
const N_CANDIDATES: usize = 1_000;
78
/// Builder-style Bayesian-optimization search over strategy parameters.
///
/// Warm-starts with Latin-hypercube samples, then iteratively evaluates the
/// candidate that maximizes a UCB acquisition under a kernel-regression
/// surrogate. Every unset knob falls back to the `DEFAULT_*` constants above.
#[derive(Debug, Clone, Default)]
pub struct BayesianSearch {
    /// `(name, range)` pairs; insertion order fixes the sample dimensions.
    params: Vec<(String, ParamRange)>,
    /// Metric to maximize; `None` means Sharpe ratio.
    metric: Option<OptimizeMetric>,
    /// Total evaluation budget; `None` means `DEFAULT_MAX_EVALUATIONS`.
    max_evaluations: Option<usize>,
    /// Warm-up sample count; `None` means `DEFAULT_INITIAL_POINTS`.
    initial_points: Option<usize>,
    /// UCB exploration weight; `None` means `DEFAULT_UCB_BETA`.
    ucb_beta: Option<f64>,
    /// RNG seed; `None` means `DEFAULT_SEED`.
    seed: Option<u64>,
}
105
impl BayesianSearch {
    /// Creates an empty search; configure it with the builder methods below.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a parameter to optimize over `range`.
    ///
    /// Insertion order defines the dimension order of the internal unit-cube
    /// sample points, so seeded runs are reproducible only when parameters
    /// are added in the same order.
    pub fn param(mut self, name: impl Into<String>, range: ParamRange) -> Self {
        self.params.push((name.into(), range));
        self
    }

    /// Sets the metric to maximize (default: Sharpe ratio).
    pub fn optimize_for(mut self, metric: OptimizeMetric) -> Self {
        self.metric = Some(metric);
        self
    }

    /// Caps the total number of backtest evaluations (default: 100).
    pub fn max_evaluations(mut self, n: usize) -> Self {
        self.max_evaluations = Some(n);
        self
    }

    /// Sets the number of Latin-hypercube warm-up points. At run time the
    /// value is clamped to `[2, max_evaluations]`.
    pub fn initial_points(mut self, n: usize) -> Self {
        self.initial_points = Some(n);
        self
    }

    /// Sets the UCB exploration weight `beta` (default: 2.0). Larger values
    /// favour uncertain regions over high predicted means.
    pub fn ucb_beta(mut self, beta: f64) -> Self {
        self.ucb_beta = Some(beta);
        self
    }

    /// Seeds the internal RNG for reproducible runs (default: 42).
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Runs the full optimization and returns a report sorted best-first.
    ///
    /// Phase 1 evaluates the warm-up Latin-hypercube samples; phase 2 spends
    /// the remaining budget picking, each iteration, the best of
    /// `N_CANDIDATES` random points under the surrogate's UCB acquisition.
    ///
    /// # Errors
    /// Fails when no parameter ranges were configured, when no parameter set
    /// had enough data to complete a backtest, or when every completed
    /// backtest scored NaN on the target metric.
    pub fn run<S, F>(
        &self,
        symbol: &str,
        candles: &[Candle],
        config: &BacktestConfig,
        factory: F,
    ) -> Result<OptimizationReport>
    where
        S: Strategy,
        F: Fn(&HashMap<String, ParamValue>) -> S,
    {
        if self.params.is_empty() {
            return Err(BacktestError::invalid_param(
                "params",
                "BayesianSearch requires at least one parameter range",
            ));
        }

        // Resolve all knobs to concrete values up front.
        let d = self.params.len();
        let metric = self.metric.unwrap_or(OptimizeMetric::SharpeRatio);
        let max_eval = self.max_evaluations.unwrap_or(DEFAULT_MAX_EVALUATIONS);
        // At least 2 warm-up points (the surrogate needs >= 2 observations),
        // but never more than the total budget.
        let n_init = self
            .initial_points
            .unwrap_or(DEFAULT_INITIAL_POINTS)
            .max(2)
            .min(max_eval);
        let beta = self.ucb_beta.unwrap_or(DEFAULT_UCB_BETA);
        let seed = self.seed.unwrap_or(DEFAULT_SEED);

        let mut rng = Xorshift64::new(seed);
        // (unit-cube point, score) pairs that feed the surrogate.
        let mut observations: Vec<(Vec<f64>, f64)> = Vec::with_capacity(max_eval);
        let mut all_results: Vec<OptimizationResult> = Vec::with_capacity(max_eval);
        // Best-so-far score after each *successful* evaluation.
        let mut convergence_curve: Vec<f64> = Vec::with_capacity(max_eval);
        let mut n_evaluations: usize = 0;
        let mut best_score: Option<f64> = None;

        // Phase 1: space-filling warm-up via Latin hypercube sampling.
        for norm_point in latin_hypercube_sample(n_init, d, &mut rng) {
            n_evaluations += 1;
            if let Some(opt_result) = run_one(
                symbol,
                candles,
                config,
                &metric,
                &factory,
                &norm_point,
                &self.params,
            ) {
                let score = metric.score(&opt_result.result);
                // Only finite scores enter the surrogate's training set.
                if score.is_finite() {
                    update_best(&mut best_score, score);
                    observations.push((norm_point, score));
                }
                if let Some(b) = best_score {
                    convergence_curve.push(b);
                }
                all_results.push(opt_result);
            }
        }

        // Phase 2: surrogate-guided search for the remaining budget.
        for _ in 0..max_eval.saturating_sub(n_init) {
            let norm_point = if observations.len() < 2 {
                // Too few observations to fit a surrogate: sample uniformly.
                (0..d).map(|_| rng.next_f64_positive()).collect()
            } else {
                // Refit the surrogate and pick the candidate with the
                // highest UCB acquisition out of N_CANDIDATES random points.
                let surrogate = Surrogate::fit(&observations, beta);
                (0..N_CANDIDATES)
                    .map(|_| {
                        (0..d)
                            .map(|_| rng.next_f64_positive())
                            .collect::<Vec<f64>>()
                    })
                    .max_by(|a, b| {
                        surrogate
                            .acquisition(a)
                            .partial_cmp(&surrogate.acquisition(b))
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    // Safe: N_CANDIDATES > 0, so the iterator is non-empty.
                    .unwrap()
            };

            n_evaluations += 1;
            if let Some(opt_result) = run_one(
                symbol,
                candles,
                config,
                &metric,
                &factory,
                &norm_point,
                &self.params,
            ) {
                let score = metric.score(&opt_result.result);
                if score.is_finite() {
                    update_best(&mut best_score, score);
                    observations.push((norm_point, score));
                }
                if let Some(b) = best_score {
                    convergence_curve.push(b);
                }
                all_results.push(opt_result);
            }
        }

        if all_results.is_empty() {
            return Err(BacktestError::invalid_param(
                "candles",
                "no parameter set had enough data to run a backtest",
            ));
        }

        sort_results_best_first(&mut all_results, metric);

        // After sorting best-first, a NaN at index 0 means everything is NaN.
        if metric.score(&all_results[0].result).is_nan() {
            return Err(BacktestError::invalid_param(
                "metric",
                "all parameter sets produced NaN for the target metric",
            ));
        }

        let strategy_name = all_results[0].result.strategy_name.clone();
        let best = all_results[0].clone();
        let total_combinations = all_results.len();

        // NOTE(review): candidates skipped by run_one are counted in
        // n_evaluations, yet skipped_errors is reported as 0 — confirm this
        // is intentional rather than an untracked counter.
        Ok(OptimizationReport {
            strategy_name,
            total_combinations,
            results: all_results,
            best,
            skipped_errors: 0,
            convergence_curve,
            n_evaluations,
        })
    }
}
309
/// Raises `best` to `score` when `score` beats the current best, or when no
/// best has been recorded yet.
///
/// A NaN `score` never replaces a finite best, because `>` with NaN is false;
/// callers additionally filter scores with `is_finite()` before recording.
#[inline]
fn update_best(best: &mut Option<f64>, score: f64) {
    // "No best yet" counts as an automatic improvement.
    if best.map_or(true, |current| score > current) {
        *best = Some(score);
    }
}
320
321fn run_one<S, F>(
324 symbol: &str,
325 candles: &[Candle],
326 config: &BacktestConfig,
327 _metric: &OptimizeMetric,
328 factory: &F,
329 norm_point: &[f64],
330 param_specs: &[(String, ParamRange)],
331) -> Option<OptimizationResult>
332where
333 S: Strategy,
334 F: Fn(&HashMap<String, ParamValue>) -> S,
335{
336 let params = denormalize(norm_point, param_specs);
337 let strategy = factory(¶ms);
338 match BacktestEngine::new(config.clone()).run(symbol, candles, strategy) {
339 Ok(result) => Some(OptimizationResult { params, result }),
340 Err(BacktestError::InsufficientData { .. }) => None,
341 Err(e) => {
342 tracing::warn!(
343 params = ?params,
344 error = %e,
345 "BayesianSearch: skipping candidate due to unexpected error"
346 );
347 None
348 }
349 }
350}
351
352fn denormalize(
354 norm_point: &[f64],
355 param_specs: &[(String, ParamRange)],
356) -> HashMap<String, ParamValue> {
357 norm_point
358 .iter()
359 .zip(param_specs.iter())
360 .map(|(&t, (name, range))| (name.clone(), range.sample_at(t)))
361 .collect()
362}
363
364fn latin_hypercube_sample(n: usize, d: usize, rng: &mut Xorshift64) -> Vec<Vec<f64>> {
373 if n == 0 {
374 return vec![];
375 }
376
377 let mut samples = vec![vec![0.0_f64; d]; n];
378
379 #[allow(clippy::needless_range_loop)]
380 for dim in 0..d {
381 let mut stratum_values: Vec<f64> = (0..n)
383 .map(|i| {
384 let lo = i as f64 / n as f64;
385 let hi = (i + 1) as f64 / n as f64;
386 lo + rng.next_f64_positive() * (hi - lo)
387 })
388 .collect();
389
390 for i in (1..n).rev() {
392 let j = rng.next_usize(i + 1);
393 stratum_values.swap(i, j);
394 }
395
396 for i in 0..n {
397 samples[i][dim] = stratum_values[i];
398 }
399 }
400
401 samples
402}
403
/// Lightweight Nadaraya-Watson (Gaussian-kernel regression) surrogate over
/// the observed (point, score) pairs, used to rank candidate points through
/// an upper-confidence-bound (UCB) acquisition.
struct Surrogate<'a> {
    /// Observed (unit-cube point, metric score) pairs, borrowed from caller.
    observations: &'a [(Vec<f64>, f64)],
    /// Exploration weight in the UCB acquisition `mean + beta * std`.
    beta: f64,
    /// Precomputed `2 * h^2` denominator of the Gaussian kernel.
    bandwidth_sq: f64,
}

impl<'a> Surrogate<'a> {
    /// Builds the surrogate. Bandwidth follows Scott's rule `n^(-1/(d+4))`,
    /// floored at 0.1 so the kernel never collapses for large `n`.
    fn fit(observations: &'a [(Vec<f64>, f64)], beta: f64) -> Self {
        let n = observations.len() as f64;
        let d = observations.first().map_or(1, |(x, _)| x.len()) as f64;
        let h = n.powf(-1.0 / (d + 4.0)).max(0.1);
        Self {
            observations,
            beta,
            bandwidth_sq: 2.0 * h * h,
        }
    }

    /// UCB acquisition: predicted mean plus `beta` standard deviations.
    fn acquisition(&self, x: &[f64]) -> f64 {
        let (mean, std) = self.predict(x);
        mean + self.beta * std
    }

    /// Predicts `(mean, std)` at `x` as a kernel-weighted average of the
    /// observations. Falls back to `(0.0, 1.0)` — maximal uncertainty — when
    /// `x` is so far from every observation that all weights underflow.
    fn predict(&self, x: &[f64]) -> (f64, f64) {
        // Fix: compute each kernel weight exactly once. The previous version
        // re-evaluated `rbf` for every observation in a second pass for the
        // variance, doubling the kernel work per prediction. Accumulation
        // order is unchanged, so results are bit-identical.
        let weights: Vec<f64> = self
            .observations
            .iter()
            .map(|(xi, _)| self.rbf(x, xi))
            .collect();

        let w_sum: f64 = weights.iter().sum();
        if w_sum < f64::EPSILON {
            return (0.0, 1.0);
        }

        let wy_sum: f64 = weights
            .iter()
            .zip(self.observations)
            .map(|(w, (_, yi))| w * yi)
            .sum();
        let mean = wy_sum / w_sum;

        let wvar: f64 = weights
            .iter()
            .zip(self.observations)
            .map(|(w, (_, yi))| {
                let diff = yi - mean;
                w * diff * diff
            })
            .sum();
        // `max(0.0)` guards the sqrt against tiny negative rounding error.
        let std = (wvar / w_sum).max(0.0).sqrt();

        (mean, std)
    }

    /// Gaussian RBF kernel `exp(-|x - xi|^2 / (2 h^2))`.
    #[inline]
    fn rbf(&self, x: &[f64], xi: &[f64]) -> f64 {
        let dist_sq: f64 = x.iter().zip(xi.iter()).map(|(a, b)| (a - b).powi(2)).sum();
        (-dist_sq / self.bandwidth_sq).exp()
    }
}
482
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::{BacktestConfig, SmaCrossover};
    use crate::models::chart::Candle;

    /// Builds one synthetic candle per price with +/-1% high/low bands.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1_000,
                adj_close: Some(p),
            })
            .collect()
    }

    /// Linearly rising price series: 100.0, 100.5, 101.0, ...
    fn trending_prices(n: usize) -> Vec<f64> {
        (0..n).map(|i| 100.0 + i as f64 * 0.5).collect()
    }

    // --- Latin hypercube sampling ---

    #[test]
    fn test_lhs_shape() {
        let mut rng = Xorshift64::new(1);
        let samples = latin_hypercube_sample(8, 3, &mut rng);
        assert_eq!(samples.len(), 8);
        assert!(samples.iter().all(|p| p.len() == 3));
    }

    #[test]
    fn test_lhs_stratification() {
        let n = 10;
        let mut rng = Xorshift64::new(99);
        let samples = latin_hypercube_sample(n, 2, &mut rng);

        // Each of the n strata in each dimension must hold exactly one point.
        for dim in 0..2 {
            let mut counts = vec![0usize; n];
            for point in &samples {
                let stratum = (point[dim] * n as f64).floor() as usize;
                counts[stratum.min(n - 1)] += 1;
            }
            assert!(
                counts.iter().all(|&c| c == 1),
                "dim {dim}: expected one sample per stratum, got {counts:?}"
            );
        }
    }

    #[test]
    fn test_lhs_values_in_unit_cube() {
        let mut rng = Xorshift64::new(7);
        for point in latin_hypercube_sample(20, 4, &mut rng) {
            for v in point {
                // next_f64_positive excludes 0, so values lie in (0, 1].
                assert!(v > 0.0 && v <= 1.0, "value {v} outside (0, 1]");
            }
        }
    }

    // --- Surrogate model ---

    #[test]
    fn test_surrogate_predicts_near_observation() {
        let obs = vec![(vec![0.5_f64], 1.0_f64)];
        let s = Surrogate::fit(&obs, 2.0);
        let (mean, _) = s.predict(&[0.5]);
        assert!((mean - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_surrogate_max_uncertainty_fallback_for_very_distant_point() {
        let obs = vec![(vec![0.0_f64], 0.5_f64), (vec![0.1], 0.6)];
        let s = Surrogate::fit(&obs, 2.0);
        // All kernel weights underflow at distance 100 -> fallback (0.0, 1.0).
        let (mean, std) = s.predict(&[100.0]);
        assert!(
            (mean - 0.0).abs() < 1e-6,
            "expected fallback mean=0.0, got {mean}"
        );
        assert!(
            (std - 1.0).abs() < 1e-6,
            "expected fallback std=1.0, got {std}"
        );
    }

    #[test]
    fn test_surrogate_std_nonzero_with_disagreeing_observations() {
        let obs = vec![(vec![0.0_f64], 0.1_f64), (vec![0.05], 0.9)];
        let s = Surrogate::fit(&obs, 2.0);
        let (_, std) = s.predict(&[0.025]);
        assert!(
            std > 0.1,
            "expected non-trivial std for disagreeing observations, got {std}"
        );
    }

    #[test]
    fn test_acquisition_favours_uncertain_regions_with_high_beta() {
        let obs = vec![(vec![0.0_f64], 0.5_f64), (vec![0.1], 0.6)];
        let s = Surrogate::fit(&obs, 10.0);
        assert!(
            s.acquisition(&[1.0]) > s.acquisition(&[0.05]),
            "far point should have higher UCB with β=10"
        );
    }

    // --- End-to-end search runs ---

    #[test]
    fn test_bayesian_search_runs() {
        let candles = make_candles(&trending_prices(200));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 10))
            .param("slow", ParamRange::int_bounds(10, 30))
            .optimize_for(OptimizeMetric::TotalReturn)
            .max_evaluations(20)
            .seed(1)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        assert!(!report.results.is_empty());
        assert_eq!(report.strategy_name, "SMA Crossover");
        assert!(report.n_evaluations > 0);
        assert!(!report.convergence_curve.is_empty());
    }

    #[test]
    fn test_convergence_curve_is_nondecreasing() {
        let candles = make_candles(&trending_prices(200));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 15))
            .param("slow", ParamRange::int_bounds(15, 40))
            .max_evaluations(30)
            .seed(2)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        // The curve records best-so-far, so it can only stay flat or rise.
        for window in report.convergence_curve.windows(2) {
            assert!(
                window[1] >= window[0] - 1e-12,
                "convergence curve not non-decreasing: {window:?}"
            );
        }
    }

    #[test]
    fn test_results_sorted_best_first() {
        let candles = make_candles(&trending_prices(150));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 10))
            .param("slow", ParamRange::int_bounds(10, 25))
            .optimize_for(OptimizeMetric::TotalReturn)
            .max_evaluations(15)
            .seed(3)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        if report.results.len() > 1 {
            let first = OptimizeMetric::TotalReturn.score(&report.results[0].result);
            let second = OptimizeMetric::TotalReturn.score(&report.results[1].result);
            assert!(first >= second - 1e-12);
        }
    }

    #[test]
    fn test_best_matches_results_first() {
        let candles = make_candles(&trending_prices(150));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 10))
            .param("slow", ParamRange::int_bounds(10, 25))
            .max_evaluations(15)
            .seed(4)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        // `best` is a clone of results[0]; default metric is Sharpe ratio.
        let best = OptimizeMetric::SharpeRatio.score(&report.best.result);
        let first = OptimizeMetric::SharpeRatio.score(&report.results[0].result);
        assert!((best - first).abs() < 1e-12);
    }

    #[test]
    fn test_no_params_returns_error() {
        let candles = make_candles(&trending_prices(100));
        let config = BacktestConfig::default();
        assert!(
            BayesianSearch::new()
                .run("TEST", &candles, &config, |_| SmaCrossover::new(5, 20))
                .is_err()
        );
    }

    #[test]
    fn test_seeded_runs_are_reproducible() {
        let candles = make_candles(&trending_prices(200));
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let search = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 12))
            .param("slow", ParamRange::int_bounds(12, 30))
            .max_evaluations(15)
            .seed(77);

        let factory = |p: &HashMap<String, ParamValue>| {
            SmaCrossover::new(p["fast"].as_int() as usize, p["slow"].as_int() as usize)
        };

        // Same seed + same config must replay the identical search path.
        let r1 = search
            .clone()
            .run("TEST", &candles, &config, factory)
            .unwrap();
        let r2 = search.run("TEST", &candles, &config, factory).unwrap();

        assert_eq!(r1.n_evaluations, r2.n_evaluations);
        assert_eq!(r1.convergence_curve, r2.convergence_curve);
        assert_eq!(
            r1.best.result.metrics.total_return_pct,
            r2.best.result.metrics.total_return_pct
        );
    }
}