Skip to main content

finance_query/backtesting/optimizer/
grid.rs

1//! Exhaustive grid-search parameter optimisation.
2//!
3//! Use [`GridSearch`] to sweep over all combinations of named parameter ranges
4//! and rank them by a chosen metric. All combinations run in parallel via
5//! `rayon`, and results are returned sorted best-first.
6//!
7//! # Example
8//!
9//! ```ignore
10//! use finance_query::backtesting::{
11//!     BacktestConfig, SmaCrossover,
12//!     optimizer::{GridSearch, OptimizeMetric, ParamRange, ParamValue},
13//! };
14//!
15//! # fn example(candles: &[finance_query::models::chart::Candle]) {
16//! let report = GridSearch::new()
17//!     .param("fast", ParamRange::int_range(5, 50, 5))
18//!     .param("slow", ParamRange::int_range(20, 200, 10))
19//!     .optimize_for(OptimizeMetric::SharpeRatio)
20//!     .run("AAPL", candles, &BacktestConfig::default(), |params| {
21//!         SmaCrossover::new(
22//!             params["fast"].as_int() as usize,
23//!             params["slow"].as_int() as usize,
24//!         )
25//!     })
26//!     .unwrap();
27//!
28//! println!("Best params: {:?}", report.best.params);
29//! println!("Best Sharpe: {:.2}", report.best.result.metrics.sharpe_ratio);
30//! # }
31//! ```
32
33use std::collections::HashMap;
34use std::sync::atomic::{AtomicUsize, Ordering};
35
36use rayon::prelude::*;
37
38use crate::models::chart::Candle;
39
40use super::super::config::BacktestConfig;
41use super::super::engine::BacktestEngine;
42use super::super::error::{BacktestError, Result};
43use super::super::strategy::Strategy;
44use super::{
45    OptimizationReport, OptimizationResult, OptimizeMetric, ParamRange, ParamValue,
46    sort_results_best_first,
47};
48
49// ── GridSearch ────────────────────────────────────────────────────────────────
50
51/// Exhaustive grid-search optimiser for backtesting strategy parameters.
52///
53/// Evaluates every combination of the supplied parameter ranges in parallel.
54/// Use [`BayesianSearch`] instead when the cartesian product would exceed
55/// a few thousand combinations or when float ranges without a step are needed.
56///
57/// # Overfitting Warning
58///
59/// Results are **in-sample only**. Follow up with [`WalkForwardConfig`] or a
60/// held-out test window to obtain an unbiased out-of-sample estimate.
61///
62/// [`BayesianSearch`]: super::BayesianSearch
63/// [`WalkForwardConfig`]: super::super::walk_forward::WalkForwardConfig
64#[derive(Debug, Clone, Default)]
65pub struct GridSearch {
66    /// Named parameter ranges, in insertion order (for reproducibility).
67    params: Vec<(String, ParamRange)>,
68    /// Metric to maximise (defaults to `SharpeRatio`).
69    metric: Option<OptimizeMetric>,
70}
71
72impl GridSearch {
73    /// Create a new grid search with no parameters defined yet.
74    pub fn new() -> Self {
75        Self::default()
76    }
77
78    /// Add a named parameter range to sweep.
79    ///
80    /// Parameters are expanded in cartesian-product order: the last parameter
81    /// added cycles fastest (inner loop).
82    pub fn param(mut self, name: impl Into<String>, range: ParamRange) -> Self {
83        self.params.push((name.into(), range));
84        self
85    }
86
87    /// Set the metric to optimise for (defaults to [`OptimizeMetric::SharpeRatio`]).
88    pub fn optimize_for(mut self, metric: OptimizeMetric) -> Self {
89        self.metric = Some(metric);
90        self
91    }
92
93    /// Run the grid search.
94    ///
95    /// `symbol` is used only for labelling in the returned results.
96    ///
97    /// `factory` receives the current parameter map and returns a strategy
98    /// instance. Combinations that exceed the strategy's warmup period are
99    /// silently skipped.
100    ///
101    /// Returns an error when the grid is empty or all combinations were skipped.
102    pub fn run<S, F>(
103        &self,
104        symbol: &str,
105        candles: &[Candle],
106        config: &BacktestConfig,
107        factory: F,
108    ) -> Result<OptimizationReport>
109    where
110        S: Strategy + Send,
111        F: Fn(&HashMap<String, ParamValue>) -> S + Send + Sync,
112    {
113        if self.params.is_empty() {
114            return Err(BacktestError::invalid_param(
115                "params",
116                "grid search requires at least one parameter range",
117            ));
118        }
119
120        let metric = self.metric.unwrap_or(OptimizeMetric::SharpeRatio);
121
122        let expanded: Vec<(&str, Vec<ParamValue>)> = self
123            .params
124            .iter()
125            .map(|(name, range)| (name.as_str(), range.expand()))
126            .collect();
127
128        let combinations = cartesian_product(&expanded);
129        let total_combinations = combinations.len();
130
131        if total_combinations == 0 {
132            return Err(BacktestError::invalid_param(
133                "params",
134                "all parameter ranges produced empty value sets \
135                 (hint: float_bounds is not compatible with GridSearch — use BayesianSearch)",
136            ));
137        }
138
139        if total_combinations > 10_000 {
140            tracing::warn!(
141                total_combinations,
142                "grid search: large combination count — consider BayesianSearch or wider steps"
143            );
144        }
145
146        let skipped_errors = AtomicUsize::new(0);
147        let mut results: Vec<OptimizationResult> = combinations
148            .into_par_iter()
149            .filter_map(|params| {
150                let strategy = factory(&params);
151                match BacktestEngine::new(config.clone()).run(symbol, candles, strategy) {
152                    Ok(result) => Some(OptimizationResult { params, result }),
153                    Err(BacktestError::InsufficientData { .. }) => None,
154                    Err(e) => {
155                        tracing::warn!(
156                            params = ?params,
157                            error = %e,
158                            "grid search: skipping combination due to unexpected error"
159                        );
160                        skipped_errors.fetch_add(1, Ordering::Relaxed);
161                        None
162                    }
163                }
164            })
165            .collect();
166        let skipped_errors = skipped_errors.into_inner();
167
168        if results.is_empty() {
169            return Err(BacktestError::invalid_param(
170                "candles",
171                "no parameter combination had enough data to run",
172            ));
173        }
174
175        sort_results_best_first(&mut results, metric);
176
177        if metric.score(&results[0].result).is_nan() {
178            return Err(BacktestError::invalid_param(
179                "metric",
180                "all parameter combinations produced NaN for the target metric",
181            ));
182        }
183
184        let strategy_name = results[0].result.strategy_name.clone();
185        let best = results[0].clone();
186        let n_evaluations = total_combinations;
187
188        Ok(OptimizationReport {
189            strategy_name,
190            total_combinations,
191            results,
192            best,
193            skipped_errors,
194            // GridSearch runs all combinations in parallel — no sequential ordering,
195            // so the convergence curve is meaningless and left empty.
196            convergence_curve: vec![],
197            n_evaluations,
198        })
199    }
200}
201
202// ── Internal helpers ──────────────────────────────────────────────────────────
203
204/// Compute the cartesian product of named parameter value lists.
205///
206/// Returns a `Vec` of `HashMap`s, one per combination. The last parameter
207/// cycles fastest (inner loop).
208fn cartesian_product(params: &[(&str, Vec<ParamValue>)]) -> Vec<HashMap<String, ParamValue>> {
209    if params.is_empty() {
210        return vec![];
211    }
212
213    let mut result: Vec<HashMap<String, ParamValue>> = vec![HashMap::new()];
214
215    for (name, values) in params {
216        let mut next = Vec::with_capacity(result.len() * values.len());
217        for existing in &result {
218            for value in values {
219                let mut combo = existing.clone();
220                combo.insert(name.to_string(), value.clone());
221                next.push(combo);
222            }
223        }
224        result = next;
225    }
226
227    result
228}
229
230// ── Tests ─────────────────────────────────────────────────────────────────────
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::{BacktestConfig, SmaCrossover};
    use crate::models::chart::Candle;

    /// Build one synthetic candle per price: open and close equal the price,
    /// high/low sit 1% above/below it, and the index doubles as the timestamp.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1000,
                adj_close: Some(p),
                provider_id: None,
            })
            .collect()
    }

    /// `n` prices rising linearly from 100.0 in steps of 0.5.
    fn trending_prices(n: usize) -> Vec<f64> {
        (0..n).map(|i| 100.0 + i as f64 * 0.5).collect()
    }

    // ── ParamValue ────────────────────────────────────────────────────────────

    #[test]
    fn test_param_value_conversion() {
        let iv = ParamValue::Int(10);
        assert_eq!(iv.as_int(), 10);
        assert!((iv.as_float() - 10.0).abs() < f64::EPSILON);

        let fv = ParamValue::Float(1.5);
        assert_eq!(fv.as_int(), 1);
        assert!((fv.as_float() - 1.5).abs() < f64::EPSILON);
    }

    // ── ParamRange expansion (grid path) ─────────────────────────────────────

    #[test]
    fn test_int_range_expand() {
        let r = ParamRange::int_range(5, 20, 5);
        let vals = r.expand();
        assert_eq!(
            vals,
            vec![
                ParamValue::Int(5),
                ParamValue::Int(10),
                ParamValue::Int(15),
                ParamValue::Int(20),
            ]
        );
    }

    #[test]
    fn test_float_range_expand() {
        let r = ParamRange::float_range(0.1, 0.3, 0.1);
        let vals = r.expand();
        assert_eq!(vals.len(), 3);
        assert!((vals[0].as_float() - 0.1).abs() < 1e-9);
        assert!((vals[2].as_float() - 0.3).abs() < 1e-9);
    }

    /// Floating-point arithmetic can produce `start + N * step` slightly above
    /// `end`. The endpoint must be clamped exactly to `end` with no extra values.
    #[test]
    fn test_float_range_endpoint_clamping() {
        let vals = ParamRange::float_range(0.1, 0.5, 0.1).expand();
        assert_eq!(vals.len(), 5, "should have exactly 5 values [0.1…0.5]");
        assert!(
            (vals[4].as_float() - 0.5).abs() < 1e-12,
            "endpoint must be exactly 0.5"
        );

        // step that doesn't evenly divide the range
        let vals2 = ParamRange::float_range(0.1, 0.5, 0.15).expand();
        assert_eq!(vals2.len(), 4);
        assert!((vals2[3].as_float() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_float_bounds_expand_returns_empty() {
        // float_bounds has step=0.0, which is intentionally invalid for GridSearch.
        let r = ParamRange::float_bounds(0.1, 0.9);
        assert!(r.expand().is_empty());
    }

    // ── ParamRange sampling (Bayesian path) ───────────────────────────────────

    #[test]
    fn test_int_bounds_sample_at() {
        let r = ParamRange::int_bounds(5, 50);
        assert_eq!(r.sample_at(0.0), ParamValue::Int(5));
        assert_eq!(r.sample_at(1.0), ParamValue::Int(50));
        assert!(matches!(r.sample_at(0.5), ParamValue::Int(_)));
    }

    #[test]
    fn test_float_bounds_sample_at() {
        let r = ParamRange::float_bounds(0.3, 0.7);
        assert!((r.sample_at(0.0).as_float() - 0.3).abs() < 1e-12);
        assert!((r.sample_at(1.0).as_float() - 0.7).abs() < 1e-12);
        assert!((r.sample_at(0.5).as_float() - 0.5).abs() < 1e-12);
        assert!(matches!(r.sample_at(0.5), ParamValue::Float(_)));
    }

    #[test]
    fn test_sample_at_int_range() {
        let r = ParamRange::int_bounds(0, 9);
        assert_eq!(r.sample_at(0.0), ParamValue::Int(0));
        assert_eq!(r.sample_at(1.0), ParamValue::Int(9));
        assert_eq!(r.sample_at(0.5), ParamValue::Int(5));
    }

    #[test]
    fn test_sample_at_values_range() {
        let r = ParamRange::Values(vec![
            ParamValue::Int(10),
            ParamValue::Int(20),
            ParamValue::Int(30),
        ]);
        assert_eq!(r.sample_at(0.0), ParamValue::Int(10));
        assert_eq!(r.sample_at(1.0), ParamValue::Int(30));
        assert_eq!(r.sample_at(0.5), ParamValue::Int(20));
    }

    // ── cartesian_product ─────────────────────────────────────────────────────

    /// Pins both the contents of every combination and the documented order:
    /// the last parameter cycles fastest.
    #[test]
    fn test_cartesian_product() {
        let params: Vec<(&str, Vec<ParamValue>)> = vec![
            ("a", vec![ParamValue::Int(1), ParamValue::Int(2)]),
            ("b", vec![ParamValue::Int(10), ParamValue::Int(20)]),
        ];
        let combos = cartesian_product(&params);
        assert_eq!(combos.len(), 4);

        let expected = [(1, 10), (1, 20), (2, 10), (2, 20)];
        for (combo, &(a, b)) in combos.iter().zip(expected.iter()) {
            assert_eq!(combo.len(), 2);
            assert_eq!(combo["a"], ParamValue::Int(a));
            assert_eq!(combo["b"], ParamValue::Int(b));
        }
    }

    /// No parameters, or any parameter without values, yields no combinations.
    #[test]
    fn test_cartesian_product_empty_inputs() {
        assert!(cartesian_product(&[]).is_empty());

        let params: Vec<(&str, Vec<ParamValue>)> = vec![
            ("a", vec![ParamValue::Int(1), ParamValue::Int(2)]),
            ("b", vec![]),
        ];
        assert!(cartesian_product(&params).is_empty());
    }

    // ── GridSearch integration ────────────────────────────────────────────────

    #[test]
    fn test_grid_search_runs() {
        let prices = trending_prices(100);
        let candles = make_candles(&prices);
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 10, 3))
            .param("slow", ParamRange::int_range(10, 20, 10))
            .optimize_for(OptimizeMetric::TotalReturn)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        assert!(!report.results.is_empty());
        assert_eq!(report.strategy_name, "SMA Crossover");
        assert!(
            report.convergence_curve.is_empty(),
            "GridSearch curve should be empty"
        );
        assert_eq!(report.n_evaluations, report.total_combinations);

        // Results must come back best-first for the requested metric.
        if report.results.len() > 1 {
            let first = OptimizeMetric::TotalReturn.score(&report.results[0].result);
            let second = OptimizeMetric::TotalReturn.score(&report.results[1].result);
            assert!(first >= second);
        }
    }

    #[test]
    fn test_grid_search_no_params_errors() {
        let candles = make_candles(&trending_prices(50));
        let config = BacktestConfig::default();
        let result = GridSearch::new().run("TEST", &candles, &config, |_| SmaCrossover::new(5, 10));
        assert!(result.is_err());
    }

    #[test]
    fn test_grid_search_float_bounds_errors() {
        // float_bounds is incompatible with GridSearch (step=0.0 → empty expansion).
        let candles = make_candles(&trending_prices(100));
        let config = BacktestConfig::default();
        let result = GridSearch::new()
            .param("x", ParamRange::float_bounds(0.1, 0.9))
            .run("TEST", &candles, &config, |_| SmaCrossover::new(5, 20));
        assert!(result.is_err());
    }

    #[test]
    fn test_optimize_metric_min_drawdown() {
        let prices = trending_prices(60);
        let candles = make_candles(&prices);
        let config = BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let report = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 9, 3))
            .param("slow", ParamRange::int_range(10, 20, 10))
            .optimize_for(OptimizeMetric::MinDrawdown)
            .run("TEST", &candles, &config, |params| {
                SmaCrossover::new(
                    params["fast"].as_int() as usize,
                    params["slow"].as_int() as usize,
                )
            })
            .unwrap();

        assert!(!report.results.is_empty());
        // Best-first for MinDrawdown means the smallest drawdown leads.
        if report.results.len() > 1 {
            let first = report.results[0].result.metrics.max_drawdown_pct;
            let second = report.results[1].result.metrics.max_drawdown_pct;
            assert!(first <= second + 1e-9, "best has smallest drawdown");
        }
    }
}