// finance_query/backtesting/optimizer/grid.rs

use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicUsize, Ordering};

use rayon::prelude::*;

use crate::models::chart::Candle;

use super::super::config::BacktestConfig;
use super::super::engine::BacktestEngine;
use super::super::error::{BacktestError, Result};
use super::super::strategy::Strategy;
use super::{
    OptimizationReport, OptimizationResult, OptimizeMetric, ParamRange, ParamValue,
    sort_results_best_first,
};
48
49#[derive(Debug, Clone, Default)]
65pub struct GridSearch {
66 params: Vec<(String, ParamRange)>,
68 metric: Option<OptimizeMetric>,
70}
71
72impl GridSearch {
73 pub fn new() -> Self {
75 Self::default()
76 }
77
78 pub fn param(mut self, name: impl Into<String>, range: ParamRange) -> Self {
83 self.params.push((name.into(), range));
84 self
85 }
86
87 pub fn optimize_for(mut self, metric: OptimizeMetric) -> Self {
89 self.metric = Some(metric);
90 self
91 }
92
93 pub fn run<S, F>(
103 &self,
104 symbol: &str,
105 candles: &[Candle],
106 config: &BacktestConfig,
107 factory: F,
108 ) -> Result<OptimizationReport>
109 where
110 S: Strategy + Send,
111 F: Fn(&HashMap<String, ParamValue>) -> S + Send + Sync,
112 {
113 if self.params.is_empty() {
114 return Err(BacktestError::invalid_param(
115 "params",
116 "grid search requires at least one parameter range",
117 ));
118 }
119
120 let metric = self.metric.unwrap_or(OptimizeMetric::SharpeRatio);
121
122 let expanded: Vec<(&str, Vec<ParamValue>)> = self
123 .params
124 .iter()
125 .map(|(name, range)| (name.as_str(), range.expand()))
126 .collect();
127
128 let combinations = cartesian_product(&expanded);
129 let total_combinations = combinations.len();
130
131 if total_combinations == 0 {
132 return Err(BacktestError::invalid_param(
133 "params",
134 "all parameter ranges produced empty value sets \
135 (hint: float_bounds is not compatible with GridSearch — use BayesianSearch)",
136 ));
137 }
138
139 if total_combinations > 10_000 {
140 tracing::warn!(
141 total_combinations,
142 "grid search: large combination count — consider BayesianSearch or wider steps"
143 );
144 }
145
146 let skipped_errors = AtomicUsize::new(0);
147 let mut results: Vec<OptimizationResult> = combinations
148 .into_par_iter()
149 .filter_map(|params| {
150 let strategy = factory(¶ms);
151 match BacktestEngine::new(config.clone()).run(symbol, candles, strategy) {
152 Ok(result) => Some(OptimizationResult { params, result }),
153 Err(BacktestError::InsufficientData { .. }) => None,
154 Err(e) => {
155 tracing::warn!(
156 params = ?params,
157 error = %e,
158 "grid search: skipping combination due to unexpected error"
159 );
160 skipped_errors.fetch_add(1, Ordering::Relaxed);
161 None
162 }
163 }
164 })
165 .collect();
166 let skipped_errors = skipped_errors.into_inner();
167
168 if results.is_empty() {
169 return Err(BacktestError::invalid_param(
170 "candles",
171 "no parameter combination had enough data to run",
172 ));
173 }
174
175 sort_results_best_first(&mut results, metric);
176
177 if metric.score(&results[0].result).is_nan() {
178 return Err(BacktestError::invalid_param(
179 "metric",
180 "all parameter combinations produced NaN for the target metric",
181 ));
182 }
183
184 let strategy_name = results[0].result.strategy_name.clone();
185 let best = results[0].clone();
186 let n_evaluations = total_combinations;
187
188 Ok(OptimizationReport {
189 strategy_name,
190 total_combinations,
191 results,
192 best,
193 skipped_errors,
194 convergence_curve: vec![],
197 n_evaluations,
198 })
199 }
200}
201
202fn cartesian_product(params: &[(&str, Vec<ParamValue>)]) -> Vec<HashMap<String, ParamValue>> {
209 if params.is_empty() {
210 return vec![];
211 }
212
213 let mut result: Vec<HashMap<String, ParamValue>> = vec![HashMap::new()];
214
215 for (name, values) in params {
216 let mut next = Vec::with_capacity(result.len() * values.len());
217 for existing in &result {
218 for value in values {
219 let mut combo = existing.clone();
220 combo.insert(name.to_string(), value.clone());
221 next.push(combo);
222 }
223 }
224 result = next;
225 }
226
227 result
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::{BacktestConfig, SmaCrossover};
    use crate::models::chart::Candle;

    /// Builds one candle per price, with a ±1% high/low band and fixed volume.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1000,
                adj_close: Some(p),
                provider_id: None,
            })
            .collect()
    }

    /// `n` prices rising linearly from 100.0 in steps of 0.5.
    fn trending_prices(n: usize) -> Vec<f64> {
        (0..n).map(|i| 100.0 + i as f64 * 0.5).collect()
    }

    #[test]
    fn test_param_value_conversion() {
        let iv = ParamValue::Int(10);
        assert_eq!(iv.as_int(), 10);
        assert!((iv.as_float() - 10.0).abs() < f64::EPSILON);

        let fv = ParamValue::Float(1.5);
        assert_eq!(fv.as_int(), 1);
        assert!((fv.as_float() - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_int_range_expand() {
        let r = ParamRange::int_range(5, 20, 5);
        let vals = r.expand();
        assert_eq!(
            vals,
            vec![
                ParamValue::Int(5),
                ParamValue::Int(10),
                ParamValue::Int(15),
                ParamValue::Int(20),
            ]
        );
    }

    #[test]
    fn test_float_range_expand() {
        let r = ParamRange::float_range(0.1, 0.3, 0.1);
        let vals = r.expand();
        assert_eq!(vals.len(), 3);
        assert!((vals[0].as_float() - 0.1).abs() < 1e-9);
        assert!((vals[2].as_float() - 0.3).abs() < 1e-9);
    }

    #[test]
    fn test_float_range_endpoint_clamping() {
        let vals = ParamRange::float_range(0.1, 0.5, 0.1).expand();
        assert_eq!(vals.len(), 5, "should have exactly 5 values [0.1…0.5]");
        assert!(
            (vals[4].as_float() - 0.5).abs() < 1e-12,
            "endpoint must be exactly 0.5"
        );

        // A step that does not divide the span evenly still ends on the bound.
        let vals2 = ParamRange::float_range(0.1, 0.5, 0.15).expand();
        assert_eq!(vals2.len(), 4);
        assert!((vals2[3].as_float() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_float_bounds_expand_returns_empty() {
        // Continuous bounds have no finite grid.
        let r = ParamRange::float_bounds(0.1, 0.9);
        assert!(r.expand().is_empty());
    }

    #[test]
    fn test_int_bounds_sample_at() {
        let r = ParamRange::int_bounds(5, 50);
        assert_eq!(r.sample_at(0.0), ParamValue::Int(5));
        assert_eq!(r.sample_at(1.0), ParamValue::Int(50));
        assert!(matches!(r.sample_at(0.5), ParamValue::Int(_)));
    }

    #[test]
    fn test_float_bounds_sample_at() {
        let r = ParamRange::float_bounds(0.3, 0.7);
        assert!((r.sample_at(0.0).as_float() - 0.3).abs() < 1e-12);
        assert!((r.sample_at(1.0).as_float() - 0.7).abs() < 1e-12);
        assert!((r.sample_at(0.5).as_float() - 0.5).abs() < 1e-12);
        assert!(matches!(r.sample_at(0.5), ParamValue::Float(_)));
    }

    #[test]
    fn test_sample_at_int_range() {
        let r = ParamRange::int_bounds(0, 9);
        assert_eq!(r.sample_at(0.0), ParamValue::Int(0));
        assert_eq!(r.sample_at(1.0), ParamValue::Int(9));
        assert_eq!(r.sample_at(0.5), ParamValue::Int(5));
    }

    #[test]
    fn test_sample_at_values_range() {
        let r = ParamRange::Values(vec![
            ParamValue::Int(10),
            ParamValue::Int(20),
            ParamValue::Int(30),
        ]);
        assert_eq!(r.sample_at(0.0), ParamValue::Int(10));
        assert_eq!(r.sample_at(1.0), ParamValue::Int(30));
        assert_eq!(r.sample_at(0.5), ParamValue::Int(20));
    }

    #[test]
    fn test_cartesian_product() {
        let params: Vec<(&str, Vec<ParamValue>)> = vec![
            ("a", vec![ParamValue::Int(1), ParamValue::Int(2)]),
            ("b", vec![ParamValue::Int(10), ParamValue::Int(20)]),
        ];
        let combos = cartesian_product(&params);
        assert_eq!(combos.len(), 4);
        assert!(combos.iter().all(|c| c.len() == 2));

        // The first parameter varies slowest and the last varies fastest.
        let pairs: Vec<(ParamValue, ParamValue)> = combos
            .iter()
            .map(|c| (c["a"].clone(), c["b"].clone()))
            .collect();
        assert_eq!(
            pairs,
            vec![
                (ParamValue::Int(1), ParamValue::Int(10)),
                (ParamValue::Int(1), ParamValue::Int(20)),
                (ParamValue::Int(2), ParamValue::Int(10)),
                (ParamValue::Int(2), ParamValue::Int(20)),
            ]
        );
    }

    #[test]
    fn test_cartesian_product_empty_inputs() {
        assert!(cartesian_product(&[]).is_empty());

        // A single empty value set empties the whole product.
        let params: Vec<(&str, Vec<ParamValue>)> = vec![
            ("a", vec![ParamValue::Int(1), ParamValue::Int(2)]),
            ("b", vec![]),
        ];
        assert!(cartesian_product(&params).is_empty());
    }

    #[test]
    fn test_grid_search_runs() {
        let prices = trending_prices(100);
        let candles = make_candles(&prices);
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 10, 3))
            .param("slow", ParamRange::int_range(10, 20, 10))
            .optimize_for(OptimizeMetric::TotalReturn)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        assert!(!report.results.is_empty());
        assert_eq!(report.strategy_name, "SMA Crossover");
        assert!(
            report.convergence_curve.is_empty(),
            "GridSearch curve should be empty"
        );
        assert_eq!(report.n_evaluations, report.total_combinations);

        if report.results.len() > 1 {
            let first = OptimizeMetric::TotalReturn.score(&report.results[0].result);
            let second = OptimizeMetric::TotalReturn.score(&report.results[1].result);
            assert!(first >= second);
        }
    }

    #[test]
    fn test_grid_search_no_params_errors() {
        let candles = make_candles(&trending_prices(50));
        let config = BacktestConfig::default();
        let result = GridSearch::new().run("TEST", &candles, &config, |_| SmaCrossover::new(5, 10));
        assert!(result.is_err());
    }

    #[test]
    fn test_grid_search_float_bounds_errors() {
        let candles = make_candles(&trending_prices(100));
        let config = BacktestConfig::default();
        let result = GridSearch::new()
            .param("x", ParamRange::float_bounds(0.1, 0.9))
            .run("TEST", &candles, &config, |_| SmaCrossover::new(5, 20));
        assert!(result.is_err());
    }

    #[test]
    fn test_optimize_metric_min_drawdown() {
        let prices = trending_prices(60);
        let candles = make_candles(&prices);
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 9, 3))
            .param("slow", ParamRange::int_range(10, 20, 10))
            .optimize_for(OptimizeMetric::MinDrawdown)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        assert!(!report.results.is_empty());
        if report.results.len() > 1 {
            let first = report.results[0].result.metrics.max_drawdown_pct;
            let second = report.results[1].result.metrics.max_drawdown_pct;
            assert!(first <= second + 1e-9, "best has smallest drawdown");
        }
    }
}