finance_query/backtesting/optimizer/
grid.rs1use std::collections::HashMap;
34use std::sync::atomic::{AtomicUsize, Ordering};
35
36use rayon::prelude::*;
37
38use crate::models::chart::Candle;
39
40use super::super::config::BacktestConfig;
41use super::super::engine::BacktestEngine;
42use super::super::error::{BacktestError, Result};
43use super::super::strategy::Strategy;
44use super::{
45 OptimizationReport, OptimizationResult, OptimizeMetric, ParamRange, ParamValue,
46 sort_results_best_first,
47};
48
49#[derive(Debug, Clone, Default)]
65pub struct GridSearch {
66 params: Vec<(String, ParamRange)>,
68 metric: Option<OptimizeMetric>,
70}
71
72impl GridSearch {
73 pub fn new() -> Self {
75 Self::default()
76 }
77
78 pub fn param(mut self, name: impl Into<String>, range: ParamRange) -> Self {
83 self.params.push((name.into(), range));
84 self
85 }
86
87 pub fn optimize_for(mut self, metric: OptimizeMetric) -> Self {
89 self.metric = Some(metric);
90 self
91 }
92
93 pub fn run<S, F>(
103 &self,
104 symbol: &str,
105 candles: &[Candle],
106 config: &BacktestConfig,
107 factory: F,
108 ) -> Result<OptimizationReport>
109 where
110 S: Strategy + Send,
111 F: Fn(&HashMap<String, ParamValue>) -> S + Send + Sync,
112 {
113 if self.params.is_empty() {
114 return Err(BacktestError::invalid_param(
115 "params",
116 "grid search requires at least one parameter range",
117 ));
118 }
119
120 let metric = self.metric.unwrap_or(OptimizeMetric::SharpeRatio);
121
122 let expanded: Vec<(&str, Vec<ParamValue>)> = self
123 .params
124 .iter()
125 .map(|(name, range)| (name.as_str(), range.expand()))
126 .collect();
127
128 let combinations = cartesian_product(&expanded);
129 let total_combinations = combinations.len();
130
131 if total_combinations == 0 {
132 return Err(BacktestError::invalid_param(
133 "params",
134 "all parameter ranges produced empty value sets \
135 (hint: float_bounds is not compatible with GridSearch — use BayesianSearch)",
136 ));
137 }
138
139 if total_combinations > 10_000 {
140 tracing::warn!(
141 total_combinations,
142 "grid search: large combination count — consider BayesianSearch or wider steps"
143 );
144 }
145
146 let skipped_errors = AtomicUsize::new(0);
147 let mut results: Vec<OptimizationResult> = combinations
148 .into_par_iter()
149 .filter_map(|params| {
150 let strategy = factory(¶ms);
151 match BacktestEngine::new(config.clone()).run(symbol, candles, strategy) {
152 Ok(result) => Some(OptimizationResult { params, result }),
153 Err(BacktestError::InsufficientData { .. }) => None,
154 Err(e) => {
155 tracing::warn!(
156 params = ?params,
157 error = %e,
158 "grid search: skipping combination due to unexpected error"
159 );
160 skipped_errors.fetch_add(1, Ordering::Relaxed);
161 None
162 }
163 }
164 })
165 .collect();
166 let skipped_errors = skipped_errors.into_inner();
167
168 if results.is_empty() {
169 return Err(BacktestError::invalid_param(
170 "candles",
171 "no parameter combination had enough data to run",
172 ));
173 }
174
175 sort_results_best_first(&mut results, metric);
176
177 if metric.score(&results[0].result).is_nan() {
178 return Err(BacktestError::invalid_param(
179 "metric",
180 "all parameter combinations produced NaN for the target metric",
181 ));
182 }
183
184 let strategy_name = results[0].result.strategy_name.clone();
185 let best = results[0].clone();
186 let n_evaluations = total_combinations;
187
188 Ok(OptimizationReport {
189 strategy_name,
190 total_combinations,
191 results,
192 best,
193 skipped_errors,
194 convergence_curve: vec![],
197 n_evaluations,
198 })
199 }
200}
201
202fn cartesian_product(params: &[(&str, Vec<ParamValue>)]) -> Vec<HashMap<String, ParamValue>> {
209 if params.is_empty() {
210 return vec![];
211 }
212
213 let mut result: Vec<HashMap<String, ParamValue>> = vec![HashMap::new()];
214
215 for (name, values) in params {
216 let mut next = Vec::with_capacity(result.len() * values.len());
217 for existing in &result {
218 for value in values {
219 let mut combo = existing.clone();
220 combo.insert(name.to_string(), value.clone());
221 next.push(combo);
222 }
223 }
224 result = next;
225 }
226
227 result
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::{BacktestConfig, SmaCrossover};
    use crate::models::chart::Candle;

    /// One candle per price: open = close = price, high/low at ±1%, constant volume.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1000,
                adj_close: Some(p),
            })
            .collect()
    }

    /// Steadily rising prices: 100.0, 100.5, 101.0, ...
    fn trending_prices(n: usize) -> Vec<f64> {
        (0..n).map(|i| 100.0 + i as f64 * 0.5).collect()
    }

    /// Commission- and slippage-free config, so results depend only on the parameters.
    fn frictionless_config() -> BacktestConfig {
        BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap()
    }

    /// Strategy factory shared by the grid tests: reads the `fast` and `slow` SMA periods.
    fn sma_from_params(params: &HashMap<String, ParamValue>) -> SmaCrossover {
        SmaCrossover::new(
            params["fast"].as_int() as usize,
            params["slow"].as_int() as usize,
        )
    }

    #[test]
    fn test_param_value_conversion() {
        let iv = ParamValue::Int(10);
        assert_eq!(iv.as_int(), 10);
        assert!((iv.as_float() - 10.0).abs() < f64::EPSILON);

        let fv = ParamValue::Float(1.5);
        assert_eq!(fv.as_int(), 1);
        assert!((fv.as_float() - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_int_range_expand() {
        let r = ParamRange::int_range(5, 20, 5);
        let vals = r.expand();
        assert_eq!(
            vals,
            vec![
                ParamValue::Int(5),
                ParamValue::Int(10),
                ParamValue::Int(15),
                ParamValue::Int(20),
            ]
        );
    }

    #[test]
    fn test_float_range_expand() {
        let r = ParamRange::float_range(0.1, 0.3, 0.1);
        let vals = r.expand();
        assert_eq!(vals.len(), 3);
        assert!((vals[0].as_float() - 0.1).abs() < 1e-9);
        assert!((vals[2].as_float() - 0.3).abs() < 1e-9);
    }

    #[test]
    fn test_float_range_endpoint_clamping() {
        let vals = ParamRange::float_range(0.1, 0.5, 0.1).expand();
        assert_eq!(vals.len(), 5, "should have exactly 5 values [0.1…0.5]");
        assert!(
            (vals[4].as_float() - 0.5).abs() < 1e-12,
            "endpoint must be exactly 0.5"
        );

        let vals2 = ParamRange::float_range(0.1, 0.5, 0.15).expand();
        assert_eq!(vals2.len(), 4);
        assert!((vals2[3].as_float() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_float_bounds_expand_returns_empty() {
        let r = ParamRange::float_bounds(0.1, 0.9);
        assert!(r.expand().is_empty());
    }

    #[test]
    fn test_int_bounds_sample_at() {
        let r = ParamRange::int_bounds(5, 50);
        assert_eq!(r.sample_at(0.0), ParamValue::Int(5));
        assert_eq!(r.sample_at(1.0), ParamValue::Int(50));
        assert!(matches!(r.sample_at(0.5), ParamValue::Int(_)));
    }

    #[test]
    fn test_float_bounds_sample_at() {
        let r = ParamRange::float_bounds(0.3, 0.7);
        assert!((r.sample_at(0.0).as_float() - 0.3).abs() < 1e-12);
        assert!((r.sample_at(1.0).as_float() - 0.7).abs() < 1e-12);
        assert!((r.sample_at(0.5).as_float() - 0.5).abs() < 1e-12);
        assert!(matches!(r.sample_at(0.5), ParamValue::Float(_)));
    }

    #[test]
    fn test_sample_at_int_range() {
        let r = ParamRange::int_bounds(0, 9);
        assert_eq!(r.sample_at(0.0), ParamValue::Int(0));
        assert_eq!(r.sample_at(1.0), ParamValue::Int(9));
        assert_eq!(r.sample_at(0.5), ParamValue::Int(5));
    }

    #[test]
    fn test_sample_at_values_range() {
        let r = ParamRange::Values(vec![
            ParamValue::Int(10),
            ParamValue::Int(20),
            ParamValue::Int(30),
        ]);
        assert_eq!(r.sample_at(0.0), ParamValue::Int(10));
        assert_eq!(r.sample_at(1.0), ParamValue::Int(30));
        assert_eq!(r.sample_at(0.5), ParamValue::Int(20));
    }

    #[test]
    fn test_cartesian_product() {
        let params: Vec<(&str, Vec<ParamValue>)> = vec![
            ("a", vec![ParamValue::Int(1), ParamValue::Int(2)]),
            ("b", vec![ParamValue::Int(10), ParamValue::Int(20)]),
        ];
        let combos = cartesian_product(&params);
        assert_eq!(combos.len(), 4);
        // Every combination carries exactly one value for each parameter...
        assert!(combos
            .iter()
            .all(|c| c.len() == 2 && c.contains_key("a") && c.contains_key("b")));
        // ...and every pairing of values appears.
        for a in [ParamValue::Int(1), ParamValue::Int(2)] {
            for b in [ParamValue::Int(10), ParamValue::Int(20)] {
                assert!(combos.iter().any(|c| c["a"] == a && c["b"] == b));
            }
        }
        // No parameters means no combinations (not one empty combination).
        assert!(cartesian_product(&[]).is_empty());
    }

    #[test]
    fn test_grid_search_runs() {
        let candles = make_candles(&trending_prices(100));
        let config = frictionless_config();

        let report = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 10, 3))
            .param("slow", ParamRange::int_range(10, 20, 10))
            .optimize_for(OptimizeMetric::TotalReturn)
            .run("TEST", &candles, &config, sma_from_params)
            .unwrap();

        assert!(!report.results.is_empty());
        assert_eq!(report.strategy_name, "SMA Crossover");
        assert!(
            report.convergence_curve.is_empty(),
            "GridSearch curve should be empty"
        );
        assert_eq!(report.n_evaluations, report.total_combinations);
        assert_eq!(report.best.params, report.results[0].params);

        // The whole list, not just the first pair, must be ordered best-first.
        for pair in report.results.windows(2) {
            let ahead = OptimizeMetric::TotalReturn.score(&pair[0].result);
            let behind = OptimizeMetric::TotalReturn.score(&pair[1].result);
            assert!(
                behind.is_nan() || ahead >= behind,
                "results must be sorted best-first"
            );
        }
    }

    #[test]
    fn test_grid_search_no_params_errors() {
        let candles = make_candles(&trending_prices(50));
        let config = BacktestConfig::default();
        let result = GridSearch::new().run("TEST", &candles, &config, |_| SmaCrossover::new(5, 10));
        assert!(result.is_err());
    }

    #[test]
    fn test_grid_search_float_bounds_errors() {
        let candles = make_candles(&trending_prices(100));
        let config = BacktestConfig::default();
        let result = GridSearch::new()
            .param("x", ParamRange::float_bounds(0.1, 0.9))
            .run("TEST", &candles, &config, |_| SmaCrossover::new(5, 20));
        assert!(result.is_err());
    }

    #[test]
    fn test_grid_search_one_empty_range_among_valid_errors() {
        // A single float_bounds dimension empties the grid even when other ranges are valid.
        let candles = make_candles(&trending_prices(100));
        let config = BacktestConfig::default();
        let result = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 9, 3))
            .param("x", ParamRange::float_bounds(0.1, 0.9))
            .run("TEST", &candles, &config, |_| SmaCrossover::new(5, 20));
        assert!(result.is_err());
    }

    #[test]
    fn test_optimize_metric_min_drawdown() {
        let candles = make_candles(&trending_prices(60));
        let config = frictionless_config();

        let report = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 9, 3))
            .param("slow", ParamRange::int_range(10, 20, 10))
            .optimize_for(OptimizeMetric::MinDrawdown)
            .run("TEST", &candles, &config, sma_from_params)
            .unwrap();

        assert!(!report.results.is_empty());
        if report.results.len() > 1 {
            let first = report.results[0].result.metrics.max_drawdown_pct;
            let second = report.results[1].result.metrics.max_drawdown_pct;
            assert!(first <= second + 1e-9, "best has smallest drawdown");
        }
    }
}