Skip to main content

quantwave_backtest/
walk_forward.rs

1//! Walk-forward out-of-sample validation (quantwave-cr6v.14 / quantwave-xibc).
2//!
3//! Clean-room rolling OOS folds on pre-computed signals (RaptorBT / Zorro WFO pattern).
4//! v1: no in-fold parameter optimization — each fold backtests the OOS window only.
5
6use crate::{BacktestConfig, BacktestEngine, BacktestError, PerformanceMetrics};
7use polars::prelude::*;
8use std::collections::HashMap;
9
/// Bar-count settings for a rolling walk-forward run.
///
/// Every count refers to positions in the sorted, de-duplicated timestamp index.
#[derive(Debug, Clone, PartialEq)]
pub struct WalkForwardConfig {
    /// Leading in-sample bars of each fold; they are not scored out-of-sample
    /// and only push the OOS window forward.
    pub train_bars: usize,
    /// Number of out-of-sample bars backtested in each fold.
    pub test_bars: usize,
    /// Distance between consecutive fold starts; `None` falls back to `test_bars`.
    pub step_bars: Option<usize>,
    /// A fold is flagged as overfit when its train objective exceeds its
    /// out-of-sample objective by more than this amount.
    pub overfit_threshold: f64,
}

impl WalkForwardConfig {
    /// Create a config with no explicit step (so the step equals `test_bars`)
    /// and an `overfit_threshold` of `1.0`.
    pub fn new(train_bars: usize, test_bars: usize) -> Self {
        let step_bars = None;
        let overfit_threshold = 1.0;
        Self {
            train_bars,
            test_bars,
            step_bars,
            overfit_threshold,
        }
    }

    /// Effective step between folds; never less than one bar.
    fn step(&self) -> usize {
        let requested = match self.step_bars {
            Some(explicit) => explicit,
            None => self.test_bars,
        };
        // A zero step would never advance the window, so clamp it up to one.
        if requested == 0 {
            1
        } else {
            requested
        }
    }
}
36
37/// Run walk-forward OOS backtests; returns fold × metrics DataFrame.
38pub fn run_walk_forward(
39    lf: LazyFrame,
40    base_config: &BacktestConfig,
41    wf: &WalkForwardConfig,
42) -> Result<DataFrame, BacktestError> {
43    if wf.train_bars == 0 || wf.test_bars == 0 {
44        return Err(BacktestError::InvalidInput(
45            "train_bars and test_bars must be > 0".into(),
46        ));
47    }
48
49    let df = lf.collect()?;
50    if df.height() == 0 {
51        return Err(BacktestError::InvalidInput("empty dataframe".into()));
52    }
53
54    let ts_col = &base_config.timestamp_col;
55    let timestamps = unique_sorted_timestamps(&df, ts_col)?;
56    let step = wf.step();
57    let mut fold_id = 0usize;
58    let mut fold_ids = Vec::new();
59    let mut oos_start = Vec::new();
60    let mut oos_end = Vec::new();
61    let mut train_lens = Vec::new();
62    let mut test_lens = Vec::new();
63    let mut metric_cols: HashMap<&'static str, Vec<f64>> = PerformanceMetrics::column_names()
64        .iter()
65        .map(|&n| (n, Vec::new()))
66        .collect();
67
68    let mut start = 0usize;
69    while start + wf.train_bars + wf.test_bars <= timestamps.len() {
70        let test_start_idx = start + wf.train_bars;
71        let test_end_idx = test_start_idx + wf.test_bars;
72        let ts_min = timestamps[test_start_idx];
73        let ts_max = timestamps[test_end_idx - 1];
74
75        let oos_lf = df
76            .clone()
77            .lazy()
78            .filter(col(ts_col).gt_eq(lit(ts_min)).and(col(ts_col).lt_eq(lit(ts_max))));
79
80        let report = BacktestEngine::new(base_config.clone()).backtest_with_report(oos_lf)?;
81
82        fold_ids.push(fold_id as f64);
83        oos_start.push(ts_min as f64);
84        oos_end.push(ts_max as f64);
85        train_lens.push(wf.train_bars as f64);
86        test_lens.push(wf.test_bars as f64);
87        for (name, value) in report.metrics.row_iter() {
88            metric_cols.get_mut(name).unwrap().push(value);
89        }
90
91        fold_id += 1;
92        start += step;
93    }
94
95    if fold_ids.is_empty() {
96        return Err(BacktestError::InvalidInput(format!(
97            "insufficient bars for walk-forward: need >= {} unique timestamps, got {}",
98            wf.train_bars + wf.test_bars,
99            timestamps.len()
100        )));
101    }
102
103    let mut columns = vec![
104        Column::new("fold_id".into(), fold_ids),
105        Column::new("oos_start_ts".into(), oos_start),
106        Column::new("oos_end_ts".into(), oos_end),
107        Column::new("train_bars".into(), train_lens),
108        Column::new("test_bars".into(), test_lens),
109    ];
110    for name in PerformanceMetrics::column_names() {
111        columns.push(Column::new(
112            PlSmallStr::from_str(name),
113            metric_cols.remove(name).unwrap(),
114        ));
115    }
116
117    DataFrame::new(columns).map_err(BacktestError::from)
118}
119
120fn unique_sorted_timestamps(df: &DataFrame, ts_col: &str) -> Result<Vec<i64>, BacktestError> {
121    let ts = df
122        .column(ts_col)
123        .map_err(|e| BacktestError::InvalidInput(e.to_string()))?;
124    let mut values: Vec<i64> = match ts.dtype() {
125        DataType::Int64 => ts.i64().unwrap().into_iter().flatten().collect(),
126        DataType::Int32 => ts
127            .i32()
128            .unwrap()
129            .into_iter()
130            .flatten()
131            .map(|v| v as i64)
132            .collect(),
133        other => {
134            return Err(BacktestError::InvalidInput(format!(
135                "timestamp column must be Int64/Int32, got {other:?}"
136            )));
137        }
138    };
139    values.sort_unstable();
140    values.dedup();
141    Ok(values)
142}
143
144/// Run walk-forward optimization: sweep on train fold, pick best by objective, backtest OOS.
145pub fn run_walk_forward_optimize(
146    lf: LazyFrame,
147    base_config: &BacktestConfig,
148    wf: &WalkForwardConfig,
149    variants: &[crate::SweepVariant],
150    objective_metric: &str,
151) -> Result<DataFrame, BacktestError> {
152    if wf.train_bars == 0 || wf.test_bars == 0 {
153        return Err(BacktestError::InvalidInput("train/test_bars must be > 0".into()));
154    }
155    if variants.is_empty() {
156        return Err(BacktestError::InvalidInput("at least one variant required".into()));
157    }
158
159    let df = lf.collect()?;
160    if df.height() == 0 {
161        return Err(BacktestError::InvalidInput("empty dataframe".into()));
162    }
163
164    let ts_col = &base_config.timestamp_col;
165    let timestamps = unique_sorted_timestamps(&df, ts_col)?;
166    let step = wf.step();
167    let param_keys = crate::sweep::sorted_param_keys(variants);
168    
169    let mut fold_ids = Vec::new();
170    let mut oos_starts = Vec::new();
171    let mut oos_ends = Vec::new();
172    let mut train_metrics = Vec::new();
173    let mut oos_metrics = Vec::new();
174    let mut overfit_flags = Vec::new();
175    let mut best_params: HashMap<String, Vec<f64>> = param_keys.iter().map(|k| (k.clone(), Vec::new())).collect();
176    
177    let mut metric_cols: HashMap<&'static str, Vec<f64>> = PerformanceMetrics::column_names()
178        .iter().map(|&n| (n, Vec::new())).collect();
179
180    let mut start = 0usize;
181    let mut fold_id = 0usize;
182    while start + wf.train_bars + wf.test_bars <= timestamps.len() {
183        let test_start_idx = start + wf.train_bars;
184        let test_end_idx = test_start_idx + wf.test_bars;
185        let ts_train_start = timestamps[start];
186        let ts_train_end = timestamps[test_start_idx - 1];
187        let ts_oos_start = timestamps[test_start_idx];
188        let ts_oos_end = timestamps[test_end_idx - 1];
189
190        // 1. Train Sweep
191        let train_lf = df.clone().lazy()
192            .filter(col(ts_col).gt_eq(lit(ts_train_start)).and(col(ts_col).lt_eq(lit(ts_train_end))));
193        let sweep_df = crate::sweep::run_param_sweep(train_lf, variants, base_config)?;
194        
195        // Pick best variant
196        let obj_col = sweep_df.column(objective_metric).map_err(|e| BacktestError::InvalidInput(format!("objective_metric not found: {e}")))?;
197        let obj_series = obj_col.f64().map_err(|e| BacktestError::InvalidInput(e.to_string()))?;
198        
199        let mut best_idx = 0;
200        let mut best_val = f64::NEG_INFINITY;
201        for (i, val) in obj_series.into_iter().enumerate() {
202            if let Some(v) = val {
203                if v > best_val || (best_val == f64::NEG_INFINITY && v.is_finite()) {
204                    best_val = v;
205                    best_idx = i;
206                }
207            }
208        }
209        
210        let winning_variant = &variants[best_idx];
211        for k in &param_keys {
212            best_params.get_mut(k).unwrap().push(winning_variant.params[k]);
213        }
214        train_metrics.push(best_val);
215        
216        // 2. OOS Backtest
217        let oos_lf = df.clone().lazy()
218            .filter(col(ts_col).gt_eq(lit(ts_oos_start)).and(col(ts_col).lt_eq(lit(ts_oos_end))));
219            
220        let mut oos_config = base_config.clone();
221        oos_config.signal_col = winning_variant.signal_col.clone();
222        let report = BacktestEngine::new(oos_config).backtest_with_report(oos_lf)?;
223        
224        let oos_val = report.metrics.row_iter().find(|(n, _)| *n == objective_metric).unwrap().1;
225        oos_metrics.push(oos_val);
226        overfit_flags.push(best_val - oos_val > wf.overfit_threshold);
227        
228        for (name, value) in report.metrics.row_iter() {
229            metric_cols.get_mut(name).unwrap().push(value);
230        }
231        
232        fold_ids.push(fold_id as f64);
233        oos_starts.push(ts_oos_start as f64);
234        oos_ends.push(ts_oos_end as f64);
235        
236        fold_id += 1;
237        start += step;
238    }
239
240    if fold_ids.is_empty() {
241        return Err(BacktestError::InvalidInput("insufficient bars for wfo".into()));
242    }
243
244    let mut columns = vec![
245        Column::new("fold_id".into(), fold_ids),
246        Column::new("oos_start_ts".into(), oos_starts),
247        Column::new("oos_end_ts".into(), oos_ends),
248        Column::new("train_metric".into(), train_metrics),
249        Column::new("oos_metric".into(), oos_metrics),
250        Column::new("overfit_flag".into(), overfit_flags),
251    ];
252    for k in &param_keys {
253        columns.push(Column::new(format!("best_{k}").into(), best_params.remove(k).unwrap()));
254    }
255    for name in PerformanceMetrics::column_names() {
256        columns.push(Column::new(PlSmallStr::from_str(name), metric_cols.remove(name).unwrap()));
257    }
258
259    DataFrame::new(columns).map_err(BacktestError::from)
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// `n` hourly bars with a linearly rising close and a signal that
    /// alternates between 1.0 and 0.0 every 20 bars.
    fn wf_base_df(n: usize) -> DataFrame {
        DataFrame::new(vec![
            Column::new(
                "timestamp".into(),
                (0..n as i64).map(|i| 1_700_000_000 + i * 3600).collect::<Vec<_>>(),
            ),
            Column::new(
                "close".into(),
                (0..n).map(|i| 100.0 + i as f64 * 0.1).collect::<Vec<_>>(),
            ),
            Column::new(
                "signal".into(),
                (0..n)
                    .map(|i| if (i / 20) % 2 == 0 { 1.0 } else { 0.0 })
                    .collect::<Vec<_>>(),
            ),
        ])
        .unwrap()
    }

    /// Default config with commission and slippage set to zero.
    fn zero_cost_config() -> BacktestConfig {
        BacktestConfig {
            cost_model: crate::CostModel {
                commission_bps: 0.0,
                slippage_bps: 0.0,
                initial_cash: 100_000.0,
            },
            ..Default::default()
        }
    }

    /// Sweep variant with a single parameter `key = value` bound to `signal_col`.
    fn variant(key: &str, value: f64, signal_col: &str) -> crate::SweepVariant {
        crate::SweepVariant {
            params: std::collections::HashMap::from([(key.into(), value)]),
            signal_col: signal_col.into(),
        }
    }

    /// Despite the historical name, this configuration yields three folds.
    #[test]
    fn test_walk_forward_produces_two_folds() {
        let wf = WalkForwardConfig::new(30, 20);
        let df = run_walk_forward(
            wf_base_df(100).lazy(),
            &zero_cost_config(),
            &wf,
        )
        .unwrap();

        // 100 unique bars, train=30, test=20, step=20 → folds at 0, 20, 40
        assert_eq!(df.height(), 3);
        assert!(df.column("fold_id").is_ok());
        assert!(df.column("num_trades").is_ok());
        assert_relative_eq!(
            df.column("fold_id").unwrap().f64().unwrap().get(2).unwrap(),
            2.0,
            epsilon = 1e-9
        );
    }

    #[test]
    fn test_walk_forward_insufficient_bars_errors() {
        let wf = WalkForwardConfig::new(50, 50);
        let err = run_walk_forward(wf_base_df(60).lazy(), &zero_cost_config(), &wf)
            .unwrap_err()
            .to_string();
        assert!(err.contains("insufficient bars"));
    }

    #[test]
    fn test_walk_forward_oos_windows_do_not_overlap_when_step_equals_test() {
        let wf = WalkForwardConfig::new(20, 15);
        let df = run_walk_forward(wf_base_df(80).lazy(), &zero_cost_config(), &wf).unwrap();
        // Guard against a vacuous pass: the overlap check needs at least two folds.
        assert!(df.height() > 1);
        let starts = df.column("oos_start_ts").unwrap().f64().unwrap();
        let ends = df.column("oos_end_ts").unwrap().f64().unwrap();
        for i in 0..df.height() - 1 {
            assert!(ends.get(i).unwrap() < starts.get(i + 1).unwrap());
        }
    }

    /// Price rises by 1.0 per bar in the first half and falls by 1.0 per bar in
    /// the second half. From bar 1 onward `signal_A` is a constant 1.0 and
    /// `signal_B` a constant -1.0, so A is intended to profit in the first half
    /// and lose in the second, with B the mirror image.
    fn wfo_base_df(n: usize) -> DataFrame {
        let mut close = vec![100.0; n];
        let mut signal_a = vec![0.0; n];
        let mut signal_b = vec![0.0; n];

        for i in 1..n {
            // The signals are the same in both halves; only the price direction flips.
            signal_a[i] = 1.0;
            signal_b[i] = -1.0;
            let delta = if i < n / 2 { 1.0 } else { -1.0 };
            close[i] = close[i - 1] + delta;
        }

        DataFrame::new(vec![
            Column::new("timestamp".into(), (0..n as i64).collect::<Vec<_>>()),
            Column::new("close".into(), close),
            Column::new("signal_A".into(), signal_a),
            Column::new("signal_B".into(), signal_b),
        ])
        .unwrap()
    }

    /// The objective used here is `total_return` (the name mentions Sharpe for
    /// historical reasons).
    #[test]
    fn test_wfo_opt_picks_higher_sharpe_param_on_train() {
        let wf = WalkForwardConfig::new(20, 20); // 20 train, 20 oos (total 40 bars)
        let df = wfo_base_df(40);
        let variants = vec![
            variant("param", 1.0, "signal_A"),
            variant("param", 2.0, "signal_B"),
        ];

        let out = run_walk_forward_optimize(
            df.lazy(),
            &zero_cost_config(),
            &wf,
            &variants,
            "total_return",
        )
        .unwrap();

        assert_eq!(out.height(), 1);
        let best_param = out.column("best_param").unwrap().f64().unwrap().get(0).unwrap();
        // In train (0..20), A is profitable, so param 1.0 should be chosen
        assert_eq!(best_param, 1.0);
    }

    #[test]
    fn test_wfo_opt_oos_uses_locked_param_not_reoptimized() {
        let wf = WalkForwardConfig::new(20, 20);
        let df = wfo_base_df(40);
        let variants = vec![
            variant("param", 1.0, "signal_A"),
            variant("param", 2.0, "signal_B"),
        ];
        let out = run_walk_forward_optimize(
            df.lazy(),
            &zero_cost_config(),
            &wf,
            &variants,
            "total_return",
        )
        .unwrap();

        let oos_metric = out.column("oos_metric").unwrap().f64().unwrap().get(0).unwrap();
        // In OOS (20..40), A loses money, so total_return should be negative
        assert!(oos_metric < 0.0);
    }

    #[test]
    fn test_wfo_opt_overfit_flag_when_train_oos_diverge() {
        let mut wf = WalkForwardConfig::new(20, 20);
        // Any positive train-minus-OOS gap should trip the flag.
        wf.overfit_threshold = 0.0;
        let df = wfo_base_df(40);
        let variants = vec![variant("p", 1.0, "signal_A")];
        let out = run_walk_forward_optimize(
            df.lazy(),
            &zero_cost_config(),
            &wf,
            &variants,
            "total_return",
        )
        .unwrap();

        let overfit = out.column("overfit_flag").unwrap().bool().unwrap().get(0).unwrap();
        // Train return > 0, OOS return < 0, so the gap exceeds the zero threshold
        assert!(overfit);
    }

    #[test]
    fn test_wfo_opt_fold_count_matches_walk_forward() {
        let wf = WalkForwardConfig::new(20, 10);
        let df = wfo_base_df(60);
        let variants = vec![variant("p", 1.0, "signal_A")];
        let mut cfg = zero_cost_config();
        cfg.signal_col = "signal_A".into();
        let out1 = run_walk_forward(df.clone().lazy(), &cfg, &wf).unwrap();
        let out2 = run_walk_forward_optimize(
            df.lazy(),
            &zero_cost_config(),
            &wf,
            &variants,
            "total_return",
        )
        .unwrap();

        assert_eq!(out1.height(), out2.height());
    }
}