// finance_query/backtesting/optimizer/grid.rs
//! Exhaustive grid-search parameter optimisation.
//!
//! Use [`GridSearch`] to sweep over all combinations of named parameter ranges
//! and rank them by a chosen metric. All combinations run in parallel via
//! `rayon`, and results are returned sorted best-first.
//!
//! # Example
//!
//! ```ignore
//! use finance_query::backtesting::{
//!     BacktestConfig, SmaCrossover,
//!     optimizer::{GridSearch, OptimizeMetric, ParamRange, ParamValue},
//! };
//!
//! # fn example(candles: &[finance_query::models::chart::Candle]) {
//! let report = GridSearch::new()
//!     .param("fast", ParamRange::int_range(5, 50, 5))
//!     .param("slow", ParamRange::int_range(20, 200, 10))
//!     .optimize_for(OptimizeMetric::SharpeRatio)
//!     .run("AAPL", candles, &BacktestConfig::default(), |params| {
//!         SmaCrossover::new(
//!             params["fast"].as_int() as usize,
//!             params["slow"].as_int() as usize,
//!         )
//!     })
//!     .unwrap();
//!
//! println!("Best params: {:?}", report.best.params);
//! println!("Best Sharpe: {:.2}", report.best.result.metrics.sharpe_ratio);
//! # }
//! ```
32
33use std::collections::HashMap;
34use std::sync::atomic::{AtomicUsize, Ordering};
35
36use rayon::prelude::*;
37
38use crate::models::chart::Candle;
39
40use super::super::config::BacktestConfig;
41use super::super::engine::BacktestEngine;
42use super::super::error::{BacktestError, Result};
43use super::super::strategy::Strategy;
44use super::{
45    OptimizationReport, OptimizationResult, OptimizeMetric, ParamRange, ParamValue,
46    sort_results_best_first,
47};
48
49// ── GridSearch ────────────────────────────────────────────────────────────────
50
/// Exhaustive grid-search optimiser for backtesting strategy parameters.
///
/// Evaluates every combination of the supplied parameter ranges in parallel.
/// Use [`BayesianSearch`] instead when the cartesian product would exceed
/// a few thousand combinations or when float ranges without a step are needed.
///
/// # Overfitting Warning
///
/// Results are **in-sample only**. Follow up with [`WalkForwardConfig`] or a
/// held-out test window to obtain an unbiased out-of-sample estimate.
///
/// [`BayesianSearch`]: super::BayesianSearch
/// [`WalkForwardConfig`]: super::super::walk_forward::WalkForwardConfig
#[derive(Debug, Clone, Default)]
pub struct GridSearch {
    /// Named parameter ranges, in insertion order (for reproducibility).
    /// The last range added cycles fastest during expansion.
    params: Vec<(String, ParamRange)>,
    /// Metric to maximise; `None` means the default (`SharpeRatio`).
    metric: Option<OptimizeMetric>,
}
71
72impl GridSearch {
73    /// Create a new grid search with no parameters defined yet.
74    pub fn new() -> Self {
75        Self::default()
76    }
77
78    /// Add a named parameter range to sweep.
79    ///
80    /// Parameters are expanded in cartesian-product order: the last parameter
81    /// added cycles fastest (inner loop).
82    pub fn param(mut self, name: impl Into<String>, range: ParamRange) -> Self {
83        self.params.push((name.into(), range));
84        self
85    }
86
87    /// Set the metric to optimise for (defaults to [`OptimizeMetric::SharpeRatio`]).
88    pub fn optimize_for(mut self, metric: OptimizeMetric) -> Self {
89        self.metric = Some(metric);
90        self
91    }
92
93    /// Run the grid search.
94    ///
95    /// `symbol` is used only for labelling in the returned results.
96    ///
97    /// `factory` receives the current parameter map and returns a strategy
98    /// instance. Combinations that exceed the strategy's warmup period are
99    /// silently skipped.
100    ///
101    /// Returns an error when the grid is empty or all combinations were skipped.
102    pub fn run<S, F>(
103        &self,
104        symbol: &str,
105        candles: &[Candle],
106        config: &BacktestConfig,
107        factory: F,
108    ) -> Result<OptimizationReport>
109    where
110        S: Strategy + Send,
111        F: Fn(&HashMap<String, ParamValue>) -> S + Send + Sync,
112    {
113        if self.params.is_empty() {
114            return Err(BacktestError::invalid_param(
115                "params",
116                "grid search requires at least one parameter range",
117            ));
118        }
119
120        let metric = self.metric.unwrap_or(OptimizeMetric::SharpeRatio);
121
122        let expanded: Vec<(&str, Vec<ParamValue>)> = self
123            .params
124            .iter()
125            .map(|(name, range)| (name.as_str(), range.expand()))
126            .collect();
127
128        let combinations = cartesian_product(&expanded);
129        let total_combinations = combinations.len();
130
131        if total_combinations == 0 {
132            return Err(BacktestError::invalid_param(
133                "params",
134                "all parameter ranges produced empty value sets \
135                 (hint: float_bounds is not compatible with GridSearch — use BayesianSearch)",
136            ));
137        }
138
139        if total_combinations > 10_000 {
140            tracing::warn!(
141                total_combinations,
142                "grid search: large combination count — consider BayesianSearch or wider steps"
143            );
144        }
145
146        let skipped_errors = AtomicUsize::new(0);
147        let mut results: Vec<OptimizationResult> = combinations
148            .into_par_iter()
149            .filter_map(|params| {
150                let strategy = factory(&params);
151                match BacktestEngine::new(config.clone()).run(symbol, candles, strategy) {
152                    Ok(result) => Some(OptimizationResult { params, result }),
153                    Err(BacktestError::InsufficientData { .. }) => None,
154                    Err(e) => {
155                        tracing::warn!(
156                            params = ?params,
157                            error = %e,
158                            "grid search: skipping combination due to unexpected error"
159                        );
160                        skipped_errors.fetch_add(1, Ordering::Relaxed);
161                        None
162                    }
163                }
164            })
165            .collect();
166        let skipped_errors = skipped_errors.into_inner();
167
168        if results.is_empty() {
169            return Err(BacktestError::invalid_param(
170                "candles",
171                "no parameter combination had enough data to run",
172            ));
173        }
174
175        sort_results_best_first(&mut results, metric);
176
177        if metric.score(&results[0].result).is_nan() {
178            return Err(BacktestError::invalid_param(
179                "metric",
180                "all parameter combinations produced NaN for the target metric",
181            ));
182        }
183
184        let strategy_name = results[0].result.strategy_name.clone();
185        let best = results[0].clone();
186        let n_evaluations = total_combinations;
187
188        Ok(OptimizationReport {
189            strategy_name,
190            total_combinations,
191            results,
192            best,
193            skipped_errors,
194            // GridSearch runs all combinations in parallel — no sequential ordering,
195            // so the convergence curve is meaningless and left empty.
196            convergence_curve: vec![],
197            n_evaluations,
198        })
199    }
200}
201
202// ── Internal helpers ──────────────────────────────────────────────────────────
203
/// Compute the cartesian product of named parameter value lists.
///
/// Generic over any cloneable value type; the grid search instantiates it
/// with `V = ParamValue`. Returns one `HashMap` per combination, in
/// deterministic order: the last parameter cycles fastest (inner loop).
///
/// An empty `params` slice — or any parameter whose value list is empty —
/// yields an empty result.
fn cartesian_product<V: Clone>(params: &[(&str, Vec<V>)]) -> Vec<HashMap<String, V>> {
    if params.is_empty() {
        return vec![];
    }

    // Start from a single empty combination and extend it one axis at a time.
    let mut result: Vec<HashMap<String, V>> = vec![HashMap::new()];

    for (name, values) in params {
        // Each existing partial combination fans out once per value of this axis.
        let mut next = Vec::with_capacity(result.len() * values.len());
        for existing in &result {
            for value in values {
                let mut combo = existing.clone();
                combo.insert(name.to_string(), value.clone());
                next.push(combo);
            }
        }
        result = next;
    }

    result
}
229
230// ── Tests ─────────────────────────────────────────────────────────────────────
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::{BacktestConfig, SmaCrossover};
    use crate::models::chart::Candle;

    /// Build synthetic candles from a list of close prices; highs/lows are
    /// ±1% of the close, volume is constant.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1000,
                adj_close: Some(p),
            })
            .collect()
    }

    /// Monotonically rising prices (steady uptrend) so SMA crossovers fire.
    fn trending_prices(n: usize) -> Vec<f64> {
        (0..n).map(|i| 100.0 + i as f64 * 0.5).collect()
    }

    // ── ParamValue ────────────────────────────────────────────────────────────

    #[test]
    fn test_param_value_conversion() {
        let iv = ParamValue::Int(10);
        assert_eq!(iv.as_int(), 10);
        assert!((iv.as_float() - 10.0).abs() < f64::EPSILON);

        // as_int on a Float truncates toward zero.
        let fv = ParamValue::Float(1.5);
        assert_eq!(fv.as_int(), 1);
        assert!((fv.as_float() - 1.5).abs() < f64::EPSILON);
    }

    // ── ParamRange expansion (grid path) ─────────────────────────────────────

    #[test]
    fn test_int_range_expand() {
        // Both endpoints are inclusive.
        let r = ParamRange::int_range(5, 20, 5);
        let vals = r.expand();
        assert_eq!(
            vals,
            vec![
                ParamValue::Int(5),
                ParamValue::Int(10),
                ParamValue::Int(15),
                ParamValue::Int(20),
            ]
        );
    }

    #[test]
    fn test_float_range_expand() {
        let r = ParamRange::float_range(0.1, 0.3, 0.1);
        let vals = r.expand();
        assert_eq!(vals.len(), 3);
        assert!((vals[0].as_float() - 0.1).abs() < 1e-9);
        assert!((vals[2].as_float() - 0.3).abs() < 1e-9);
    }

    /// Floating-point arithmetic can produce `start + N * step` slightly above
    /// `end`. The endpoint must be clamped exactly to `end` with no extra values.
    #[test]
    fn test_float_range_endpoint_clamping() {
        let vals = ParamRange::float_range(0.1, 0.5, 0.1).expand();
        assert_eq!(vals.len(), 5, "should have exactly 5 values [0.1…0.5]");
        assert!(
            (vals[4].as_float() - 0.5).abs() < 1e-12,
            "endpoint must be exactly 0.5"
        );

        // step that doesn't evenly divide the range
        let vals2 = ParamRange::float_range(0.1, 0.5, 0.15).expand();
        assert_eq!(vals2.len(), 4);
        assert!((vals2[3].as_float() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_float_bounds_expand_returns_empty() {
        // float_bounds has step=0.0, which is intentionally invalid for GridSearch.
        let r = ParamRange::float_bounds(0.1, 0.9);
        assert!(r.expand().is_empty());
    }

    // ── ParamRange sampling (Bayesian path) ───────────────────────────────────

    #[test]
    fn test_int_bounds_sample_at() {
        // sample_at maps the unit interval [0, 1] onto the bounds inclusively.
        let r = ParamRange::int_bounds(5, 50);
        assert_eq!(r.sample_at(0.0), ParamValue::Int(5));
        assert_eq!(r.sample_at(1.0), ParamValue::Int(50));
        assert!(matches!(r.sample_at(0.5), ParamValue::Int(_)));
    }

    #[test]
    fn test_float_bounds_sample_at() {
        let r = ParamRange::float_bounds(0.3, 0.7);
        assert!((r.sample_at(0.0).as_float() - 0.3).abs() < 1e-12);
        assert!((r.sample_at(1.0).as_float() - 0.7).abs() < 1e-12);
        assert!((r.sample_at(0.5).as_float() - 0.5).abs() < 1e-12);
        assert!(matches!(r.sample_at(0.5), ParamValue::Float(_)));
    }

    #[test]
    fn test_sample_at_int_range() {
        let r = ParamRange::int_bounds(0, 9);
        assert_eq!(r.sample_at(0.0), ParamValue::Int(0));
        assert_eq!(r.sample_at(1.0), ParamValue::Int(9));
        assert_eq!(r.sample_at(0.5), ParamValue::Int(5));
    }

    #[test]
    fn test_sample_at_values_range() {
        // For an explicit value list, sample_at picks by scaled index.
        let r = ParamRange::Values(vec![
            ParamValue::Int(10),
            ParamValue::Int(20),
            ParamValue::Int(30),
        ]);
        assert_eq!(r.sample_at(0.0), ParamValue::Int(10));
        assert_eq!(r.sample_at(1.0), ParamValue::Int(30));
        assert_eq!(r.sample_at(0.5), ParamValue::Int(20));
    }

    // ── cartesian_product ─────────────────────────────────────────────────────

    #[test]
    fn test_cartesian_product() {
        let params: Vec<(&str, Vec<ParamValue>)> = vec![
            ("a", vec![ParamValue::Int(1), ParamValue::Int(2)]),
            ("b", vec![ParamValue::Int(10), ParamValue::Int(20)]),
        ];
        let combos = cartesian_product(&params);
        assert_eq!(combos.len(), 4);
    }

    // ── GridSearch integration ────────────────────────────────────────────────

    #[test]
    fn test_grid_search_runs() {
        let prices = trending_prices(100);
        let candles = make_candles(&prices);
        // Zero costs so the metric reflects the raw strategy signal.
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 10, 3))
            .param("slow", ParamRange::int_range(10, 20, 10))
            .optimize_for(OptimizeMetric::TotalReturn)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        assert!(!report.results.is_empty());
        assert_eq!(report.strategy_name, "SMA Crossover");
        assert!(
            report.convergence_curve.is_empty(),
            "GridSearch curve should be empty"
        );
        assert_eq!(report.n_evaluations, report.total_combinations);

        // Results must be sorted best-first by the chosen metric.
        if report.results.len() > 1 {
            let first = OptimizeMetric::TotalReturn.score(&report.results[0].result);
            let second = OptimizeMetric::TotalReturn.score(&report.results[1].result);
            assert!(first >= second);
        }
    }

    #[test]
    fn test_grid_search_no_params_errors() {
        let candles = make_candles(&trending_prices(50));
        let config = BacktestConfig::default();
        let result = GridSearch::new().run("TEST", &candles, &config, |_| SmaCrossover::new(5, 10));
        assert!(result.is_err());
    }

    #[test]
    fn test_grid_search_float_bounds_errors() {
        // float_bounds is incompatible with GridSearch (step=0.0 → empty expansion).
        let candles = make_candles(&trending_prices(100));
        let config = BacktestConfig::default();
        let result = GridSearch::new()
            .param("x", ParamRange::float_bounds(0.1, 0.9))
            .run("TEST", &candles, &config, |_| SmaCrossover::new(5, 20));
        assert!(result.is_err());
    }

    #[test]
    fn test_optimize_metric_min_drawdown() {
        let prices = trending_prices(60);
        let candles = make_candles(&prices);
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 9, 3))
            .param("slow", ParamRange::int_range(10, 20, 10))
            .optimize_for(OptimizeMetric::MinDrawdown)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        assert!(!report.results.is_empty());
        // MinDrawdown ranks ascending: smaller drawdown is better.
        if report.results.len() > 1 {
            let first = report.results[0].result.metrics.max_drawdown_pct;
            let second = report.results[1].result.metrics.max_drawdown_pct;
            assert!(first <= second + 1e-9, "best has smallest drawdown");
        }
    }
}
459}