1use std::collections::HashMap;
55
56use crate::models::chart::Candle;
57
58use super::super::config::BacktestConfig;
59use super::super::engine::BacktestEngine;
60use super::super::error::{BacktestError, Result};
61use super::super::monte_carlo::Xorshift64;
62use super::super::strategy::Strategy;
63use super::{
64 OptimizationReport, OptimizationResult, OptimizeMetric, ParamRange, ParamValue,
65 sort_results_best_first,
66};
67
/// Evaluation budget used when `max_evaluations` is not configured.
const DEFAULT_MAX_EVALUATIONS: usize = 100;
/// Latin-hypercube warm-up size used when `initial_points` is not configured.
const DEFAULT_INITIAL_POINTS: usize = 10;
/// UCB exploration weight used when `ucb_beta` is not configured.
const DEFAULT_UCB_BETA: f64 = 2.0_f64;
/// RNG seed used when `seed` is not configured, so unseeded runs are still reproducible.
const DEFAULT_SEED: u64 = 42;
/// Number of uniformly random candidates scored by the acquisition function per BO step.
const N_CANDIDATES: usize = 1000;
78
79#[derive(Debug, Clone, Default)]
97pub struct BayesianSearch {
98 params: Vec<(String, ParamRange)>,
99 metric: Option<OptimizeMetric>,
100 max_evaluations: Option<usize>,
101 initial_points: Option<usize>,
102 ucb_beta: Option<f64>,
103 seed: Option<u64>,
104}
105
106impl BayesianSearch {
107 pub fn new() -> Self {
109 Self::default()
110 }
111
112 pub fn param(mut self, name: impl Into<String>, range: ParamRange) -> Self {
117 self.params.push((name.into(), range));
118 self
119 }
120
121 pub fn optimize_for(mut self, metric: OptimizeMetric) -> Self {
123 self.metric = Some(metric);
124 self
125 }
126
127 pub fn max_evaluations(mut self, n: usize) -> Self {
129 self.max_evaluations = Some(n);
130 self
131 }
132
133 pub fn initial_points(mut self, n: usize) -> Self {
138 self.initial_points = Some(n);
139 self
140 }
141
142 pub fn ucb_beta(mut self, beta: f64) -> Self {
147 self.ucb_beta = Some(beta);
148 self
149 }
150
151 pub fn seed(mut self, seed: u64) -> Self {
153 self.seed = Some(seed);
154 self
155 }
156
157 pub fn run<S, F>(
168 &self,
169 symbol: &str,
170 candles: &[Candle],
171 config: &BacktestConfig,
172 factory: F,
173 ) -> Result<OptimizationReport>
174 where
175 S: Strategy,
176 F: Fn(&HashMap<String, ParamValue>) -> S,
177 {
178 if self.params.is_empty() {
179 return Err(BacktestError::invalid_param(
180 "params",
181 "BayesianSearch requires at least one parameter range",
182 ));
183 }
184
185 let d = self.params.len();
186 let metric = self.metric.unwrap_or(OptimizeMetric::SharpeRatio);
187 let max_eval = self.max_evaluations.unwrap_or(DEFAULT_MAX_EVALUATIONS);
188 let n_init = self
189 .initial_points
190 .unwrap_or(DEFAULT_INITIAL_POINTS)
191 .max(2)
192 .min(max_eval);
193 let beta = self.ucb_beta.unwrap_or(DEFAULT_UCB_BETA);
194 let seed = self.seed.unwrap_or(DEFAULT_SEED);
195
196 let mut rng = Xorshift64::new(seed);
197 let mut observations: Vec<(Vec<f64>, f64)> = Vec::with_capacity(max_eval);
199 let mut all_results: Vec<OptimizationResult> = Vec::with_capacity(max_eval);
200 let mut convergence_curve: Vec<f64> = Vec::with_capacity(max_eval);
202 let mut n_evaluations: usize = 0;
203 let mut best_score: Option<f64> = None;
204
205 for norm_point in latin_hypercube_sample(n_init, d, &mut rng) {
208 n_evaluations += 1;
209 if let Some(opt_result) = run_one(
210 symbol,
211 candles,
212 config,
213 &metric,
214 &factory,
215 &norm_point,
216 &self.params,
217 ) {
218 let score = metric.score(&opt_result.result);
219 if score.is_finite() {
220 update_best(&mut best_score, score);
221 observations.push((norm_point, score));
222 }
223 if let Some(b) = best_score {
224 convergence_curve.push(b);
225 }
226 all_results.push(opt_result);
227 }
228 }
229
230 for _ in 0..max_eval.saturating_sub(n_init) {
233 let norm_point = if observations.len() < 2 {
234 (0..d).map(|_| rng.next_f64_positive()).collect()
236 } else {
237 let surrogate = Surrogate::fit(&observations, beta);
238 let mut candidate = vec![0.0_f64; d];
241 let mut best_ucb = f64::NEG_INFINITY;
242 let mut best = vec![0.0_f64; d];
243 for _ in 0..N_CANDIDATES {
244 for xi in candidate.iter_mut() {
245 *xi = rng.next_f64_positive();
246 }
247 let ucb = surrogate.acquisition(&candidate);
248 if ucb > best_ucb {
249 best_ucb = ucb;
250 best.copy_from_slice(&candidate);
251 }
252 }
253 best
254 };
255
256 n_evaluations += 1;
257 if let Some(opt_result) = run_one(
258 symbol,
259 candles,
260 config,
261 &metric,
262 &factory,
263 &norm_point,
264 &self.params,
265 ) {
266 let score = metric.score(&opt_result.result);
267 if score.is_finite() {
268 update_best(&mut best_score, score);
269 observations.push((norm_point, score));
270 }
271 if let Some(b) = best_score {
272 convergence_curve.push(b);
273 }
274 all_results.push(opt_result);
275 }
276 }
277
278 if all_results.is_empty() {
281 return Err(BacktestError::invalid_param(
282 "candles",
283 "no parameter set had enough data to run a backtest",
284 ));
285 }
286
287 sort_results_best_first(&mut all_results, metric);
288
289 if metric.score(&all_results[0].result).is_nan() {
290 return Err(BacktestError::invalid_param(
291 "metric",
292 "all parameter sets produced NaN for the target metric",
293 ));
294 }
295
296 let strategy_name = all_results[0].result.strategy_name.clone();
297 let best = all_results[0].clone();
298 let total_combinations = all_results.len();
299
300 Ok(OptimizationReport {
301 strategy_name,
302 total_combinations,
303 results: all_results,
304 best,
305 skipped_errors: 0,
306 convergence_curve,
307 n_evaluations,
308 })
309 }
310}
311
/// Raises `best` to `score` when `best` is unset or `score` is strictly larger.
///
/// Because the comparison is a strict `>`, a NaN `score` never replaces an existing best.
#[inline]
fn update_best(best: &mut Option<f64>, score: f64) {
    let improves = match *best {
        Some(current) => score > current,
        None => true,
    };
    if improves {
        *best = Some(score);
    }
}
322
323fn run_one<S, F>(
326 symbol: &str,
327 candles: &[Candle],
328 config: &BacktestConfig,
329 _metric: &OptimizeMetric,
330 factory: &F,
331 norm_point: &[f64],
332 param_specs: &[(String, ParamRange)],
333) -> Option<OptimizationResult>
334where
335 S: Strategy,
336 F: Fn(&HashMap<String, ParamValue>) -> S,
337{
338 let params = denormalize(norm_point, param_specs);
339 let strategy = factory(¶ms);
340 match BacktestEngine::new(config.clone()).run(symbol, candles, strategy) {
341 Ok(result) => Some(OptimizationResult { params, result }),
342 Err(BacktestError::InsufficientData { .. }) => None,
343 Err(e) => {
344 tracing::warn!(
345 params = ?params,
346 error = %e,
347 "BayesianSearch: skipping candidate due to unexpected error"
348 );
349 None
350 }
351 }
352}
353
354fn denormalize(
356 norm_point: &[f64],
357 param_specs: &[(String, ParamRange)],
358) -> HashMap<String, ParamValue> {
359 norm_point
360 .iter()
361 .zip(param_specs.iter())
362 .map(|(&t, (name, range))| (name.clone(), range.sample_at(t)))
363 .collect()
364}
365
366fn latin_hypercube_sample(n: usize, d: usize, rng: &mut Xorshift64) -> Vec<Vec<f64>> {
375 if n == 0 {
376 return vec![];
377 }
378
379 let mut samples = vec![vec![0.0_f64; d]; n];
380
381 #[allow(clippy::needless_range_loop)]
382 for dim in 0..d {
383 let mut stratum_values: Vec<f64> = (0..n)
385 .map(|i| {
386 let lo = i as f64 / n as f64;
387 let hi = (i + 1) as f64 / n as f64;
388 lo + rng.next_f64_positive() * (hi - lo)
389 })
390 .collect();
391
392 for i in (1..n).rev() {
394 let j = rng.next_usize(i + 1);
395 stratum_values.swap(i, j);
396 }
397
398 for i in 0..n {
399 samples[i][dim] = stratum_values[i];
400 }
401 }
402
403 samples
404}
405
/// Lightweight surrogate model used by the acquisition step.
///
/// Instead of a full Gaussian process, it predicts a kernel-weighted (Gaussian RBF) mean
/// and standard deviation of the observed scores around a query point.
struct Surrogate<'a> {
    /// Observed `(normalised point, score)` pairs, borrowed rather than copied.
    observations: &'a [(Vec<f64>, f64)],
    /// Exploration weight of the UCB acquisition `mean + beta * std`.
    beta: f64,
    /// RBF denominator `2 * h^2`, where `h` is the kernel bandwidth.
    bandwidth_sq: f64,
}

impl<'a> Surrogate<'a> {
    /// Builds a surrogate over `observations` with exploration weight `beta`.
    ///
    /// The bandwidth shrinks with sample count as `h = n^(-1/(d+4))` (Scott's-rule style)
    /// and is floored at `0.1` so the kernel never becomes arbitrarily narrow.
    fn fit(observations: &'a [(Vec<f64>, f64)], beta: f64) -> Self {
        let count = observations.len() as f64;
        let dim_count = match observations.first() {
            Some((point, _)) => point.len(),
            None => 1,
        };
        let dims = dim_count as f64;
        let h = count.powf(-1.0 / (dims + 4.0)).max(0.1);
        Surrogate {
            observations,
            beta,
            bandwidth_sq: 2.0 * h * h,
        }
    }

    /// Upper-confidence-bound score `mean + beta * std` at `x`; larger is more promising.
    fn acquisition(&self, x: &[f64]) -> f64 {
        let (mu, sigma) = self.predict(x);
        mu + self.beta * sigma
    }

    /// Kernel-weighted mean and standard deviation of the observed scores at `x`.
    ///
    /// Uses West's incremental weighted mean/variance update. Observations whose kernel
    /// weight is below `f64::EPSILON` are ignored. If none remain, this returns the fixed
    /// prior `(0.0, 1.0)`.
    ///
    /// NOTE(review): that prior does not depend on the scale of the observed scores, so
    /// how strongly unexplored regions are favoured varies with the metric's units.
    fn predict(&self, x: &[f64]) -> (f64, f64) {
        let (total_weight, mean, m2) = self.observations.iter().fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |acc, (point, score)| {
                let weight = self.rbf(x, point);
                if weight < f64::EPSILON {
                    return acc;
                }
                let (w_acc, mu, m2) = acc;
                let w_next = w_acc + weight;
                let dev = *score - mu;
                let mu_next = mu + (weight / w_next) * dev;
                (w_next, mu_next, m2 + weight * dev * (*score - mu_next))
            },
        );

        if total_weight < f64::EPSILON {
            // Nothing is close enough to inform `x`: fall back to the fixed prior.
            return (0.0, 1.0);
        }

        // Clamp tiny negative round-off before the square root.
        (mean, (m2 / total_weight).max(0.0).sqrt())
    }

    /// Gaussian RBF kernel `exp(-|x - xi|^2 / bandwidth_sq)`.
    #[inline]
    fn rbf(&self, x: &[f64], xi: &[f64]) -> f64 {
        let sq_dist: f64 = x.iter().zip(xi).map(|(a, b)| (a - b).powi(2)).sum();
        (-sq_dist / self.bandwidth_sq).exp()
    }
}
487
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::{BacktestConfig, SmaCrossover};
    use crate::models::chart::Candle;

    /// Turns a close-price series into candles: open == close, high/low at ±1%,
    /// constant volume, and the bar index as the timestamp.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        let mut bars = Vec::with_capacity(prices.len());
        for (idx, &price) in prices.iter().enumerate() {
            bars.push(Candle {
                timestamp: idx as i64,
                open: price,
                high: price * 1.01,
                low: price * 0.99,
                close: price,
                volume: 1_000,
                adj_close: Some(price),
                provider_id: None,
            });
        }
        bars
    }

    /// Linear up-trend starting at 100.0 and rising by 0.5 per bar.
    fn trending_prices(n: usize) -> Vec<f64> {
        let mut prices = Vec::with_capacity(n);
        for step in 0..n {
            prices.push(100.0 + step as f64 * 0.5);
        }
        prices
    }

    /// Backtest config with commission and slippage disabled.
    fn frictionless_config() -> BacktestConfig {
        BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap()
    }

    #[test]
    fn test_lhs_shape() {
        let mut rng = Xorshift64::new(1);
        let points = latin_hypercube_sample(8, 3, &mut rng);
        assert_eq!(points.len(), 8);
        assert!(points.iter().all(|point| point.len() == 3));
    }

    #[test]
    fn test_lhs_stratification() {
        let n = 10;
        let mut rng = Xorshift64::new(99);
        let points = latin_hypercube_sample(n, 2, &mut rng);

        for dim in 0..2 {
            // Every one of the n equal-width strata must hold exactly one sample.
            let mut counts = vec![0usize; n];
            for point in &points {
                let stratum = ((point[dim] * n as f64).floor() as usize).min(n - 1);
                counts[stratum] += 1;
            }
            assert!(
                counts.iter().all(|&c| c == 1),
                "dim {dim}: expected one sample per stratum, got {counts:?}"
            );
        }
    }

    #[test]
    fn test_lhs_values_in_unit_cube() {
        let mut rng = Xorshift64::new(7);
        let points = latin_hypercube_sample(20, 4, &mut rng);
        for v in points.into_iter().flatten() {
            assert!(v > 0.0 && v <= 1.0, "value {v} outside (0, 1]");
        }
    }

    #[test]
    fn test_surrogate_predicts_near_observation() {
        let obs = vec![(vec![0.5_f64], 1.0_f64)];
        let surrogate = Surrogate::fit(&obs, 2.0);
        let (mean, _) = surrogate.predict(&[0.5]);
        assert!((mean - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_surrogate_max_uncertainty_fallback_for_very_distant_point() {
        // The query is so far away that every kernel weight drops below f64::EPSILON.
        let obs = vec![(vec![0.0_f64], 0.5_f64), (vec![0.1], 0.6)];
        let surrogate = Surrogate::fit(&obs, 2.0);
        let (mean, std) = surrogate.predict(&[100.0]);
        assert!(
            (mean - 0.0).abs() < 1e-6,
            "expected fallback mean=0.0, got {mean}"
        );
        assert!(
            (std - 1.0).abs() < 1e-6,
            "expected fallback std=1.0, got {std}"
        );
    }

    #[test]
    fn test_surrogate_std_nonzero_with_disagreeing_observations() {
        let obs = vec![(vec![0.0_f64], 0.1_f64), (vec![0.05], 0.9)];
        let surrogate = Surrogate::fit(&obs, 2.0);
        let (_, std) = surrogate.predict(&[0.025]);
        assert!(
            std > 0.1,
            "expected non-trivial std for disagreeing observations, got {std}"
        );
    }

    #[test]
    fn test_acquisition_favours_uncertain_regions_with_high_beta() {
        let obs = vec![(vec![0.0_f64], 0.5_f64), (vec![0.1], 0.6)];
        let surrogate = Surrogate::fit(&obs, 10.0);
        let far = surrogate.acquisition(&[1.0]);
        let near = surrogate.acquisition(&[0.05]);
        assert!(far > near, "far point should have higher UCB with β=10");
    }

    #[test]
    fn test_bayesian_search_runs() {
        let bars = make_candles(&trending_prices(200));
        let config = frictionless_config();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 10))
            .param("slow", ParamRange::int_bounds(10, 30))
            .optimize_for(OptimizeMetric::TotalReturn)
            .max_evaluations(20)
            .seed(1)
            .run("TEST", &bars, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        assert!(!report.results.is_empty());
        assert_eq!(report.strategy_name, "SMA Crossover");
        assert!(report.n_evaluations > 0);
        assert!(!report.convergence_curve.is_empty());
    }

    #[test]
    fn test_convergence_curve_is_nondecreasing() {
        let bars = make_candles(&trending_prices(200));
        let config = frictionless_config();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 15))
            .param("slow", ParamRange::int_bounds(15, 40))
            .max_evaluations(30)
            .seed(2)
            .run("TEST", &bars, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        for window in report.convergence_curve.windows(2) {
            assert!(
                window[1] >= window[0] - 1e-12,
                "convergence curve not non-decreasing: {window:?}"
            );
        }
    }

    #[test]
    fn test_results_sorted_best_first() {
        let bars = make_candles(&trending_prices(150));
        let config = frictionless_config();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 10))
            .param("slow", ParamRange::int_bounds(10, 25))
            .optimize_for(OptimizeMetric::TotalReturn)
            .max_evaluations(15)
            .seed(3)
            .run("TEST", &bars, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        if let [top, runner_up, ..] = report.results.as_slice() {
            let first = OptimizeMetric::TotalReturn.score(&top.result);
            let second = OptimizeMetric::TotalReturn.score(&runner_up.result);
            assert!(first >= second - 1e-12);
        }
    }

    #[test]
    fn test_best_matches_results_first() {
        let bars = make_candles(&trending_prices(150));
        let config = frictionless_config();

        let report = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 10))
            .param("slow", ParamRange::int_bounds(10, 25))
            .max_evaluations(15)
            .seed(4)
            .run("TEST", &bars, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        let best = OptimizeMetric::SharpeRatio.score(&report.best.result);
        let first = OptimizeMetric::SharpeRatio.score(&report.results[0].result);
        assert!((best - first).abs() < 1e-12);
    }

    #[test]
    fn test_no_params_returns_error() {
        let bars = make_candles(&trending_prices(100));
        let config = BacktestConfig::default();
        let outcome =
            BayesianSearch::new().run("TEST", &bars, &config, |_| SmaCrossover::new(5, 20));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_seeded_runs_are_reproducible() {
        let bars = make_candles(&trending_prices(200));
        let config = frictionless_config();

        let search = BayesianSearch::new()
            .param("fast", ParamRange::int_bounds(3, 12))
            .param("slow", ParamRange::int_bounds(12, 30))
            .max_evaluations(15)
            .seed(77);

        let factory = |p: &HashMap<String, ParamValue>| {
            SmaCrossover::new(p["fast"].as_int() as usize, p["slow"].as_int() as usize)
        };

        let first_run = search
            .clone()
            .run("TEST", &bars, &config, factory)
            .unwrap();
        let second_run = search.run("TEST", &bars, &config, factory).unwrap();

        assert_eq!(first_run.n_evaluations, second_run.n_evaluations);
        assert_eq!(first_run.convergence_curve, second_run.convergence_curve);
        assert_eq!(
            first_run.best.result.metrics.total_return_pct,
            second_run.best.result.metrics.total_return_pct
        );
    }
}