Skip to main content

finance_query/backtesting/
walk_forward.rs

//! Walk-forward parameter optimisation for backtesting strategies.
//!
//! Walk-forward testing prevents overfitting by splitting historical data into
//! rolling in-sample (training) and out-of-sample (test) windows. For each
//! window, the best parameters are discovered on the in-sample slice via grid
//! search, then validated on the subsequent out-of-sample slice.
//!
//! # How it works
//!
//! ```text
//! |--- in-sample (IS) ---|--- out-of-sample (OOS) ---|
//!            |-- step --|--- IS ---|--- OOS ---|
//!                                  |-- step --|--- IS ---|--- OOS ---|
//! ```
//!
//! Aggregate metrics from all OOS windows provide an unbiased estimate of
//! real-world strategy performance.
//!
//! # Example
//!
//! ```ignore
//! use finance_query::backtesting::{
//!     BacktestConfig, SmaCrossover,
//!     optimizer::{GridSearch, OptimizeMetric, ParamRange},
//!     walk_forward::WalkForwardConfig,
//! };
//!
//! # fn example(candles: &[finance_query::models::chart::Candle]) {
//! let grid = GridSearch::new()
//!     .param("fast", ParamRange::int_range(5, 30, 5))
//!     .param("slow", ParamRange::int_range(20, 100, 10))
//!     .optimize_for(OptimizeMetric::SharpeRatio);
//!
//! let wf = WalkForwardConfig::new(grid, BacktestConfig::default())
//!     .in_sample_bars(252)
//!     .out_of_sample_bars(63);
//!
//! let report = wf
//!     .run("AAPL", candles, |params| SmaCrossover::new(
//!         params["fast"].as_int() as usize,
//!         params["slow"].as_int() as usize,
//!     ))
//!     .unwrap();
//!
//! println!("OOS consistency: {:.1}%", report.consistency_ratio * 100.0);
//! # }
//! ```
48
49use std::collections::HashMap;
50
51use serde::{Deserialize, Serialize};
52
53use crate::models::chart::Candle;
54
55use super::config::BacktestConfig;
56use super::error::{BacktestError, Result};
57use super::optimizer::{GridSearch, OptimizationReport, ParamValue};
58use super::result::{BacktestResult, PerformanceMetrics};
59use super::strategy::Strategy;
60
61// ── Result types ─────────────────────────────────────────────────────────────
62
63/// Backtest results for a single walk-forward window pair.
64#[non_exhaustive]
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct WindowResult {
67    /// Zero-based window index
68    pub window: usize,
69    /// Parameter values selected as best on the in-sample data
70    pub optimized_params: HashMap<String, ParamValue>,
71    /// In-sample backtest result (using the best parameters)
72    pub in_sample: BacktestResult,
73    /// Out-of-sample backtest result (using the same best parameters)
74    pub out_of_sample: BacktestResult,
75}
76
77/// Aggregate walk-forward report across all windows.
78#[non_exhaustive]
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct WalkForwardReport {
81    /// Strategy name
82    pub strategy_name: String,
83    /// Per-window results
84    pub windows: Vec<WindowResult>,
85    /// Aggregate performance metrics computed from the concatenated OOS equity curves
86    pub aggregate_metrics: PerformanceMetrics,
87    /// Fraction of OOS windows that were profitable (0.0 – 1.0)
88    pub consistency_ratio: f64,
89    /// Full grid-search optimisation reports, one per window
90    pub optimization_reports: Vec<OptimizationReport>,
91}
92
93// ── WalkForwardConfig ─────────────────────────────────────────────────────────
94
95/// Configuration for a walk-forward parameter optimisation test.
96///
97/// Build with [`WalkForwardConfig::new`], configure window sizes with the
98/// builder methods, then call [`WalkForwardConfig::run`].
99#[non_exhaustive]
100#[derive(Debug, Clone)]
101pub struct WalkForwardConfig {
102    /// Grid search to use for optimising in-sample windows
103    pub grid: GridSearch,
104    /// Base backtest configuration (capital, commission, slippage, …)
105    pub config: BacktestConfig,
106    /// Number of bars in each in-sample (training) window
107    pub in_sample_bars: usize,
108    /// Number of bars in each out-of-sample (test) window
109    pub out_of_sample_bars: usize,
110    /// Number of bars to advance the window each step.
111    ///
112    /// Defaults to `out_of_sample_bars` (non-overlapping OOS windows).
113    pub step_bars: Option<usize>,
114}
115
116impl WalkForwardConfig {
117    /// Create a new walk-forward config.
118    ///
119    /// Defaults: `in_sample_bars = 252`, `out_of_sample_bars = 63`, `step_bars = None`.
120    pub fn new(grid: GridSearch, config: BacktestConfig) -> Self {
121        Self {
122            grid,
123            config,
124            in_sample_bars: 252,
125            out_of_sample_bars: 63,
126            step_bars: None,
127        }
128    }
129
130    /// Set the number of bars for each in-sample (training) window.
131    pub fn in_sample_bars(mut self, bars: usize) -> Self {
132        self.in_sample_bars = bars;
133        self
134    }
135
136    /// Set the number of bars for each out-of-sample (test) window.
137    pub fn out_of_sample_bars(mut self, bars: usize) -> Self {
138        self.out_of_sample_bars = bars;
139        self
140    }
141
142    /// Set the step size (bars to advance between windows).
143    ///
144    /// Defaults to `out_of_sample_bars` for non-overlapping OOS windows.
145    pub fn step_bars(mut self, bars: usize) -> Self {
146        self.step_bars = Some(bars);
147        self
148    }
149
150    /// Run the walk-forward test.
151    ///
152    /// `symbol` is used only for labelling. `factory` receives the parameter
153    /// map selected by each in-sample optimisation and must return a fresh
154    /// strategy instance.
155    ///
156    /// Returns an error if there is not enough data for at least one complete
157    /// window pair, or if the grid search fails on every window.
158    pub fn run<S, F>(
159        &self,
160        symbol: &str,
161        candles: &[Candle],
162        factory: F,
163    ) -> Result<WalkForwardReport>
164    where
165        S: Strategy + Clone + Send,
166        F: Fn(&HashMap<String, ParamValue>) -> S,
167        F: Send + Sync,
168    {
169        self.validate(candles.len())?;
170
171        let step = self.step_bars.unwrap_or(self.out_of_sample_bars);
172        let total_bars = self.in_sample_bars + self.out_of_sample_bars;
173
174        // Slide the window through the candle series
175        let mut windows: Vec<WindowResult> = Vec::new();
176        let mut opt_reports: Vec<OptimizationReport> = Vec::new();
177        let mut window_idx = 0;
178        let mut start = 0;
179
180        while start + total_bars <= candles.len() {
181            let is_end = start + self.in_sample_bars;
182            let oos_end = is_end + self.out_of_sample_bars;
183
184            let is_candles = &candles[start..is_end];
185            let oos_candles = &candles[is_end..oos_end];
186
187            // Optimise on the in-sample slice
188            let opt_report = self
189                .grid
190                .run(symbol, is_candles, &self.config, &factory)
191                .map_err(|e| {
192                    BacktestError::invalid_param(
193                        "walk_forward",
194                        format!("window {window_idx} optimisation failed: {e}"),
195                    )
196                })?;
197
198            let best_params = opt_report.best.params.clone();
199            let is_result = opt_report.best.result.clone();
200
201            // Test on the out-of-sample slice using the best parameters
202            let oos_strategy = factory(&best_params);
203            let oos_result = crate::backtesting::BacktestEngine::new(self.config.clone())
204                .run(symbol, oos_candles, oos_strategy)
205                .map_err(|e| {
206                    BacktestError::invalid_param(
207                        "walk_forward",
208                        format!("window {window_idx} OOS run failed: {e}"),
209                    )
210                })?;
211
212            windows.push(WindowResult {
213                window: window_idx,
214                optimized_params: best_params,
215                in_sample: is_result,
216                out_of_sample: oos_result,
217            });
218            opt_reports.push(opt_report);
219
220            start += step;
221            window_idx += 1;
222        }
223
224        if windows.is_empty() {
225            return Err(BacktestError::invalid_param(
226                "candles",
227                "not enough data for any walk-forward window",
228            ));
229        }
230
231        let strategy_name = windows[0].in_sample.strategy_name.clone();
232        let consistency_ratio = calculate_consistency_ratio(&windows);
233        let aggregate_metrics = aggregate_oos_metrics(
234            &windows,
235            self.config.risk_free_rate,
236            self.config.bars_per_year,
237        );
238
239        Ok(WalkForwardReport {
240            strategy_name,
241            windows,
242            aggregate_metrics,
243            consistency_ratio,
244            optimization_reports: opt_reports,
245        })
246    }
247
248    /// Validate the configuration before running.
249    fn validate(&self, num_candles: usize) -> Result<()> {
250        if self.in_sample_bars == 0 {
251            return Err(BacktestError::invalid_param(
252                "in_sample_bars",
253                "must be greater than zero",
254            ));
255        }
256        if self.out_of_sample_bars == 0 {
257            return Err(BacktestError::invalid_param(
258                "out_of_sample_bars",
259                "must be greater than zero",
260            ));
261        }
262        let total_bars = self.in_sample_bars + self.out_of_sample_bars;
263        if num_candles < total_bars {
264            return Err(BacktestError::insufficient_data(total_bars, num_candles));
265        }
266        Ok(())
267    }
268}
269
270// ── Internal helpers ──────────────────────────────────────────────────────────
271
272/// Fraction of OOS windows that had a positive total P&L.
273fn calculate_consistency_ratio(windows: &[WindowResult]) -> f64 {
274    if windows.is_empty() {
275        return 0.0;
276    }
277    let profitable = windows
278        .iter()
279        .filter(|w| w.out_of_sample.is_profitable())
280        .count();
281    profitable as f64 / windows.len() as f64
282}
283
284/// Compute aggregate `PerformanceMetrics` over all OOS trade lists and equity curves.
285///
286/// Concatenates trades and stitches OOS equity curves so each window starts
287/// from the previous window's ending equity.
288fn aggregate_oos_metrics(
289    windows: &[WindowResult],
290    risk_free_rate: f64,
291    bars_per_year: f64,
292) -> PerformanceMetrics {
293    use crate::backtesting::result::EquityPoint;
294
295    let all_trades: Vec<_> = windows
296        .iter()
297        .flat_map(|w| w.out_of_sample.trades.iter().cloned())
298        .collect();
299
300    // Stitch per-window equity into one continuous compounded series.
301    // Each OOS window internally resets to its own initial capital; to avoid
302    // synthetic drawdowns between windows, scale each window by the running
303    // equity level from the previous window.
304    let mut combined_equity: Vec<EquityPoint> = Vec::new();
305    // `windows` is guaranteed non-empty by the validation above; index directly.
306    let mut running_equity = windows[0].out_of_sample.initial_capital;
307
308    for (window_idx, window) in windows.iter().enumerate() {
309        let window_initial = window.out_of_sample.initial_capital;
310        if window_initial <= 0.0 {
311            continue;
312        }
313
314        for (point_idx, point) in window.out_of_sample.equity_curve.iter().enumerate() {
315            if window_idx > 0 && point_idx == 0 {
316                continue;
317            }
318
319            let scaled_equity = running_equity * (point.equity / window_initial);
320            combined_equity.push(EquityPoint {
321                timestamp: point.timestamp,
322                equity: scaled_equity,
323                drawdown_pct: 0.0,
324            });
325        }
326
327        if let Some(last) = combined_equity.last() {
328            running_equity = last.equity;
329        }
330    }
331
332    // Recompute drawdowns on the stitched curve.
333    let mut peak = f64::NEG_INFINITY;
334    for point in &mut combined_equity {
335        peak = peak.max(point.equity);
336        point.drawdown_pct = if peak > 0.0 {
337            (peak - point.equity) / peak
338        } else {
339            0.0
340        };
341    }
342
343    // Aggregate metrics use the initial capital of the first OOS window.
344    let initial_capital = windows
345        .first()
346        .map(|w| w.out_of_sample.initial_capital)
347        .unwrap_or(10_000.0);
348
349    let total_signals: usize = windows.iter().map(|w| w.out_of_sample.signals.len()).sum();
350    let executed_signals: usize = windows
351        .iter()
352        .map(|w| {
353            w.out_of_sample
354                .signals
355                .iter()
356                .filter(|s| s.executed)
357                .count()
358        })
359        .sum();
360
361    PerformanceMetrics::calculate(
362        &all_trades,
363        &combined_equity,
364        initial_capital,
365        total_signals,
366        executed_signals,
367        risk_free_rate,
368        bars_per_year,
369    )
370}
371
372// ── Tests ─────────────────────────────────────────────────────────────────────
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::{
        BacktestConfig, SmaCrossover,
        optimizer::{OptimizeMetric, ParamRange},
    };
    use crate::models::chart::Candle;

    /// Build one candle per price, with the bar index as its timestamp.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1000,
                adj_close: Some(p),
            })
            .collect()
    }

    /// Linearly rising price series of length `n`.
    fn trending_prices(n: usize) -> Vec<f64> {
        (0..n).map(|i| 100.0 + i as f64 * 0.3).collect()
    }

    /// Backtest config with commission and slippage set to zero.
    fn frictionless_config() -> BacktestConfig {
        BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap()
    }

    /// Strategy factory shared by all tests: reads the `fast` / `slow` grid
    /// parameters and builds an `SmaCrossover` from them.
    fn sma_from_params(params: &HashMap<String, ParamValue>) -> SmaCrossover {
        SmaCrossover::new(
            params["fast"].as_int() as usize,
            params["slow"].as_int() as usize,
        )
    }

    #[test]
    fn test_walk_forward_basic() {
        // 300 bars: 200 IS + 100 OOS → 1 window
        let candles = make_candles(&trending_prices(300));

        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 9, 3))
            .param("slow", ParamRange::int_range(10, 20, 10))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, frictionless_config())
            .in_sample_bars(200)
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_from_params)
            .unwrap();

        assert_eq!(report.windows.len(), 1);
        assert_eq!(report.strategy_name, "SMA Crossover");
        assert!(report.consistency_ratio >= 0.0);
        assert!(report.consistency_ratio <= 1.0);
    }

    #[test]
    fn test_walk_forward_multiple_windows() {
        // 500 bars, IS = 200, OOS = 100, step = 100 → 3 windows, starting at
        // bars 0, 100 and 200 (a start at 300 would need bars up to 600).
        let candles = make_candles(&trending_prices(500));

        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 6, 3))
            .param("slow", ParamRange::int_range(10, 10, 1))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, frictionless_config())
            .in_sample_bars(200)
            .out_of_sample_bars(100)
            .step_bars(100)
            .run("TEST", &candles, sma_from_params)
            .unwrap();

        assert_eq!(report.windows.len(), 3);
        assert_eq!(report.optimization_reports.len(), report.windows.len());
    }

    #[test]
    fn test_insufficient_data_errors() {
        let candles = make_candles(&trending_prices(50));
        let config = BacktestConfig::default();
        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 6, 3))
            .param("slow", ParamRange::int_range(10, 10, 1));

        let result = WalkForwardConfig::new(grid, config)
            .in_sample_bars(200) // more than 50 candles
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_from_params);

        assert!(result.is_err());
    }

    #[test]
    fn test_consistency_ratio_all_profitable() {
        // Strong uptrend; 300 bars with IS = 150 and OOS = 100 → 1 window.
        let prices: Vec<f64> = (0..300).map(|i| 100.0 + i as f64).collect();
        let candles = make_candles(&prices);

        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 3, 1))
            .param("slow", ParamRange::int_range(10, 10, 1))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, frictionless_config())
            .in_sample_bars(150)
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_from_params)
            .unwrap();

        // Only the valid range is asserted here: whether the crossover
        // strategy actually trades (and profits) on a monotonic series is a
        // property of `SmaCrossover`, not of the walk-forward driver.
        assert!((0.0..=1.0).contains(&report.consistency_ratio));
    }

    #[test]
    fn test_aggregate_equity_timestamps_are_monotonic() {
        // With several OOS windows, each window's equity-curve timestamps
        // must be strictly increasing.
        let prices: Vec<f64> = (0..600).map(|i| 100.0 + (i as f64) * 0.5).collect();
        let candles = make_candles(&prices);

        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 3, 1))
            .param("slow", ParamRange::int_range(10, 10, 1))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, frictionless_config())
            .in_sample_bars(100)
            .out_of_sample_bars(50)
            .run("TEST", &candles, sma_from_params)
            .unwrap();

        assert!(
            report.windows.len() >= 2,
            "Expected multiple windows for timestamp test"
        );

        // Each window's timestamps should be internally monotonic
        for window in &report.windows {
            let ts: Vec<i64> = window
                .out_of_sample
                .equity_curve
                .iter()
                .map(|ep| ep.timestamp)
                .collect();
            for pair in ts.windows(2) {
                assert!(
                    pair[0] < pair[1],
                    "Equity curve timestamps not strictly increasing within window: {} >= {}",
                    pair[0],
                    pair[1]
                );
            }
        }
    }

    #[test]
    fn test_aggregate_oos_equity_timestamps_are_gapless_across_windows() {
        // The aggregated equity curve produced by aggregate_oos_metrics must carry
        // the real OOS candle timestamps so that time-in-market calculations
        // (which divide trade duration_secs by backtest_secs) use a consistent
        // unit. Previously timestamps were replaced with auto-incrementing integers
        // (0,1,2,...) which caused the denominator to be "N bars" instead of
        // "N seconds", inflating time_in_market to 1.0 on any real-world data.
        let prices: Vec<f64> = (0..600).map(|i| 100.0 + (i as f64) * 0.5).collect();
        let candles = make_candles(&prices);

        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 3, 1))
            .param("slow", ParamRange::int_range(10, 10, 1))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, frictionless_config())
            .in_sample_bars(100)
            .out_of_sample_bars(50)
            .run("TEST", &candles, sma_from_params)
            .unwrap();

        assert!(
            report.windows.len() >= 2,
            "Need at least 2 OOS windows for this test"
        );

        // Rebuild the timestamp sequence the same way aggregate_oos_metrics
        // stitches the curve (first point of every later window dropped).
        // It must be strictly increasing (real candle timestamps, not bar indices).
        let combined_ts: Vec<i64> = report
            .windows
            .iter()
            .enumerate()
            .flat_map(|(wi, w)| {
                w.out_of_sample
                    .equity_curve
                    .iter()
                    .enumerate()
                    .filter(move |&(pi, _)| !(wi > 0 && pi == 0))
                    .map(|(_, ep)| ep.timestamp)
            })
            .collect();

        for pair in combined_ts.windows(2) {
            assert!(
                pair[0] < pair[1],
                "Combined equity curve timestamps not strictly increasing: {} >= {}",
                pair[0],
                pair[1]
            );
        }

        // Timestamps must reflect real candle timestamps — the first combined
        // timestamp should match the first OOS window's first equity point.
        let expected_first = report
            .windows
            .first()
            .and_then(|w| w.out_of_sample.equity_curve.first())
            .map(|ep| ep.timestamp)
            .unwrap_or(0);
        assert_eq!(
            combined_ts.first().copied().unwrap_or(-1),
            expected_first,
            "First combined timestamp should equal the first OOS equity point timestamp"
        );
    }
}