1use crate::{BacktestConfig, BacktestEngine, BacktestError, PerformanceMetrics};
7use polars::prelude::*;
8use std::collections::HashMap;
9
/// Window layout for rolling walk-forward evaluation.
///
/// Every fold spans `train_bars` in-sample timestamps immediately followed by
/// `test_bars` out-of-sample timestamps. Consecutive folds are shifted by
/// `step_bars` timestamps, or by `test_bars` when no explicit step is set.
#[derive(Debug, Clone, PartialEq)]
pub struct WalkForwardConfig {
    /// Length of each in-sample (training) slice, counted in unique timestamps.
    pub train_bars: usize,
    /// Length of each out-of-sample (test) slice, counted in unique timestamps.
    pub test_bars: usize,
    /// Shift between consecutive folds; `None` means "shift by `test_bars`".
    pub step_bars: Option<usize>,
    /// A fold is flagged as overfit when `train_metric - oos_metric` is strictly
    /// greater than this value.
    pub overfit_threshold: f64,
}

impl WalkForwardConfig {
    /// Creates a config with the given slice lengths, no explicit step, and an
    /// overfit threshold of `1.0`.
    pub fn new(train_bars: usize, test_bars: usize) -> Self {
        let step_bars = None;
        let overfit_threshold = 1.0;
        Self {
            train_bars,
            test_bars,
            step_bars,
            overfit_threshold,
        }
    }

    /// Effective shift between folds: the explicit `step_bars` if present,
    /// otherwise `test_bars`; clamped to at least 1 so fold enumeration always
    /// makes progress.
    fn step(&self) -> usize {
        let requested = match self.step_bars {
            Some(explicit) => explicit,
            None => self.test_bars,
        };
        requested.max(1)
    }
}
36
37pub fn run_walk_forward(
39 lf: LazyFrame,
40 base_config: &BacktestConfig,
41 wf: &WalkForwardConfig,
42) -> Result<DataFrame, BacktestError> {
43 if wf.train_bars == 0 || wf.test_bars == 0 {
44 return Err(BacktestError::InvalidInput(
45 "train_bars and test_bars must be > 0".into(),
46 ));
47 }
48
49 let df = lf.collect()?;
50 if df.height() == 0 {
51 return Err(BacktestError::InvalidInput("empty dataframe".into()));
52 }
53
54 let ts_col = &base_config.timestamp_col;
55 let timestamps = unique_sorted_timestamps(&df, ts_col)?;
56 let step = wf.step();
57 let mut fold_id = 0usize;
58 let mut fold_ids = Vec::new();
59 let mut oos_start = Vec::new();
60 let mut oos_end = Vec::new();
61 let mut train_lens = Vec::new();
62 let mut test_lens = Vec::new();
63 let mut metric_cols: HashMap<&'static str, Vec<f64>> = PerformanceMetrics::column_names()
64 .iter()
65 .map(|&n| (n, Vec::new()))
66 .collect();
67
68 let mut start = 0usize;
69 while start + wf.train_bars + wf.test_bars <= timestamps.len() {
70 let test_start_idx = start + wf.train_bars;
71 let test_end_idx = test_start_idx + wf.test_bars;
72 let ts_min = timestamps[test_start_idx];
73 let ts_max = timestamps[test_end_idx - 1];
74
75 let oos_lf = df
76 .clone()
77 .lazy()
78 .filter(col(ts_col).gt_eq(lit(ts_min)).and(col(ts_col).lt_eq(lit(ts_max))));
79
80 let report = BacktestEngine::new(base_config.clone()).backtest_with_report(oos_lf)?;
81
82 fold_ids.push(fold_id as f64);
83 oos_start.push(ts_min as f64);
84 oos_end.push(ts_max as f64);
85 train_lens.push(wf.train_bars as f64);
86 test_lens.push(wf.test_bars as f64);
87 for (name, value) in report.metrics.row_iter() {
88 metric_cols.get_mut(name).unwrap().push(value);
89 }
90
91 fold_id += 1;
92 start += step;
93 }
94
95 if fold_ids.is_empty() {
96 return Err(BacktestError::InvalidInput(format!(
97 "insufficient bars for walk-forward: need >= {} unique timestamps, got {}",
98 wf.train_bars + wf.test_bars,
99 timestamps.len()
100 )));
101 }
102
103 let mut columns = vec![
104 Column::new("fold_id".into(), fold_ids),
105 Column::new("oos_start_ts".into(), oos_start),
106 Column::new("oos_end_ts".into(), oos_end),
107 Column::new("train_bars".into(), train_lens),
108 Column::new("test_bars".into(), test_lens),
109 ];
110 for name in PerformanceMetrics::column_names() {
111 columns.push(Column::new(
112 PlSmallStr::from_str(name),
113 metric_cols.remove(name).unwrap(),
114 ));
115 }
116
117 DataFrame::new(columns).map_err(BacktestError::from)
118}
119
120fn unique_sorted_timestamps(df: &DataFrame, ts_col: &str) -> Result<Vec<i64>, BacktestError> {
121 let ts = df
122 .column(ts_col)
123 .map_err(|e| BacktestError::InvalidInput(e.to_string()))?;
124 let mut values: Vec<i64> = match ts.dtype() {
125 DataType::Int64 => ts.i64().unwrap().into_iter().flatten().collect(),
126 DataType::Int32 => ts
127 .i32()
128 .unwrap()
129 .into_iter()
130 .flatten()
131 .map(|v| v as i64)
132 .collect(),
133 other => {
134 return Err(BacktestError::InvalidInput(format!(
135 "timestamp column must be Int64/Int32, got {other:?}"
136 )));
137 }
138 };
139 values.sort_unstable();
140 values.dedup();
141 Ok(values)
142}
143
144pub fn run_walk_forward_optimize(
146 lf: LazyFrame,
147 base_config: &BacktestConfig,
148 wf: &WalkForwardConfig,
149 variants: &[crate::SweepVariant],
150 objective_metric: &str,
151) -> Result<DataFrame, BacktestError> {
152 if wf.train_bars == 0 || wf.test_bars == 0 {
153 return Err(BacktestError::InvalidInput("train/test_bars must be > 0".into()));
154 }
155 if variants.is_empty() {
156 return Err(BacktestError::InvalidInput("at least one variant required".into()));
157 }
158
159 let df = lf.collect()?;
160 if df.height() == 0 {
161 return Err(BacktestError::InvalidInput("empty dataframe".into()));
162 }
163
164 let ts_col = &base_config.timestamp_col;
165 let timestamps = unique_sorted_timestamps(&df, ts_col)?;
166 let step = wf.step();
167 let param_keys = crate::sweep::sorted_param_keys(variants);
168
169 let mut fold_ids = Vec::new();
170 let mut oos_starts = Vec::new();
171 let mut oos_ends = Vec::new();
172 let mut train_metrics = Vec::new();
173 let mut oos_metrics = Vec::new();
174 let mut overfit_flags = Vec::new();
175 let mut best_params: HashMap<String, Vec<f64>> = param_keys.iter().map(|k| (k.clone(), Vec::new())).collect();
176
177 let mut metric_cols: HashMap<&'static str, Vec<f64>> = PerformanceMetrics::column_names()
178 .iter().map(|&n| (n, Vec::new())).collect();
179
180 let mut start = 0usize;
181 let mut fold_id = 0usize;
182 while start + wf.train_bars + wf.test_bars <= timestamps.len() {
183 let test_start_idx = start + wf.train_bars;
184 let test_end_idx = test_start_idx + wf.test_bars;
185 let ts_train_start = timestamps[start];
186 let ts_train_end = timestamps[test_start_idx - 1];
187 let ts_oos_start = timestamps[test_start_idx];
188 let ts_oos_end = timestamps[test_end_idx - 1];
189
190 let train_lf = df.clone().lazy()
192 .filter(col(ts_col).gt_eq(lit(ts_train_start)).and(col(ts_col).lt_eq(lit(ts_train_end))));
193 let sweep_df = crate::sweep::run_param_sweep(train_lf, variants, base_config)?;
194
195 let obj_col = sweep_df.column(objective_metric).map_err(|e| BacktestError::InvalidInput(format!("objective_metric not found: {e}")))?;
197 let obj_series = obj_col.f64().map_err(|e| BacktestError::InvalidInput(e.to_string()))?;
198
199 let mut best_idx = 0;
200 let mut best_val = f64::NEG_INFINITY;
201 for (i, val) in obj_series.into_iter().enumerate() {
202 if let Some(v) = val {
203 if v > best_val || (best_val == f64::NEG_INFINITY && v.is_finite()) {
204 best_val = v;
205 best_idx = i;
206 }
207 }
208 }
209
210 let winning_variant = &variants[best_idx];
211 for k in ¶m_keys {
212 best_params.get_mut(k).unwrap().push(winning_variant.params[k]);
213 }
214 train_metrics.push(best_val);
215
216 let oos_lf = df.clone().lazy()
218 .filter(col(ts_col).gt_eq(lit(ts_oos_start)).and(col(ts_col).lt_eq(lit(ts_oos_end))));
219
220 let mut oos_config = base_config.clone();
221 oos_config.signal_col = winning_variant.signal_col.clone();
222 let report = BacktestEngine::new(oos_config).backtest_with_report(oos_lf)?;
223
224 let oos_val = report.metrics.row_iter().find(|(n, _)| *n == objective_metric).unwrap().1;
225 oos_metrics.push(oos_val);
226 overfit_flags.push(best_val - oos_val > wf.overfit_threshold);
227
228 for (name, value) in report.metrics.row_iter() {
229 metric_cols.get_mut(name).unwrap().push(value);
230 }
231
232 fold_ids.push(fold_id as f64);
233 oos_starts.push(ts_oos_start as f64);
234 oos_ends.push(ts_oos_end as f64);
235
236 fold_id += 1;
237 start += step;
238 }
239
240 if fold_ids.is_empty() {
241 return Err(BacktestError::InvalidInput("insufficient bars for wfo".into()));
242 }
243
244 let mut columns = vec![
245 Column::new("fold_id".into(), fold_ids),
246 Column::new("oos_start_ts".into(), oos_starts),
247 Column::new("oos_end_ts".into(), oos_ends),
248 Column::new("train_metric".into(), train_metrics),
249 Column::new("oos_metric".into(), oos_metrics),
250 Column::new("overfit_flag".into(), overfit_flags),
251 ];
252 for k in ¶m_keys {
253 columns.push(Column::new(format!("best_{k}").into(), best_params.remove(k).unwrap()));
254 }
255 for name in PerformanceMetrics::column_names() {
256 columns.push(Column::new(PlSmallStr::from_str(name), metric_cols.remove(name).unwrap()));
257 }
258
259 DataFrame::new(columns).map_err(BacktestError::from)
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Hourly bars with a slowly rising close and a signal that toggles every 20 bars.
    fn wf_base_df(n: usize) -> DataFrame {
        DataFrame::new(vec![
            Column::new(
                "timestamp".into(),
                (0..n as i64).map(|i| 1_700_000_000 + i * 3600).collect::<Vec<_>>(),
            ),
            Column::new(
                "close".into(),
                (0..n).map(|i| 100.0 + i as f64 * 0.1).collect::<Vec<_>>(),
            ),
            Column::new(
                "signal".into(),
                (0..n)
                    .map(|i| if (i / 20) % 2 == 0 { 1.0 } else { 0.0 })
                    .collect::<Vec<_>>(),
            ),
        ])
        .unwrap()
    }

    /// Default config with commission and slippage zeroed.
    fn zero_cost_config() -> BacktestConfig {
        BacktestConfig {
            cost_model: crate::CostModel {
                commission_bps: 0.0,
                slippage_bps: 0.0,
                initial_cash: 100_000.0,
            },
            ..Default::default()
        }
    }

    #[test]
    fn test_step_defaults_to_test_bars_and_is_at_least_one() {
        let mut wf = WalkForwardConfig::new(30, 20);
        assert_eq!(wf.step(), 20);
        wf.step_bars = Some(7);
        assert_eq!(wf.step(), 7);
        wf.step_bars = Some(0);
        assert_eq!(wf.step(), 1);
    }

    #[test]
    fn test_walk_forward_produces_three_folds() {
        // 100 bars, 30 train + 20 test, step 20: OOS slices start at bars 30, 50, 70.
        let wf = WalkForwardConfig::new(30, 20);
        let df = run_walk_forward(wf_base_df(100).lazy(), &zero_cost_config(), &wf).unwrap();

        assert_eq!(df.height(), 3);
        assert!(df.column("fold_id").is_ok());
        assert!(df.column("num_trades").is_ok());
        assert_relative_eq!(
            df.column("fold_id").unwrap().f64().unwrap().get(2).unwrap(),
            2.0,
            epsilon = 1e-9
        );
    }

    #[test]
    fn test_walk_forward_insufficient_bars_errors() {
        let wf = WalkForwardConfig::new(50, 50);
        let err = run_walk_forward(wf_base_df(60).lazy(), &zero_cost_config(), &wf)
            .unwrap_err()
            .to_string();
        assert!(err.contains("insufficient bars"));
    }

    #[test]
    fn test_walk_forward_oos_windows_do_not_overlap_when_step_equals_test() {
        let wf = WalkForwardConfig::new(20, 15);
        let df = run_walk_forward(wf_base_df(80).lazy(), &zero_cost_config(), &wf).unwrap();
        let starts = df.column("oos_start_ts").unwrap().f64().unwrap();
        let ends = df.column("oos_end_ts").unwrap().f64().unwrap();
        for i in 0..df.height() - 1 {
            assert!(ends.get(i).unwrap() < starts.get(i + 1).unwrap());
        }
    }

    /// Close rises by 1 per bar for the first half, then falls by 1 per bar.
    /// `signal_A` is always long and `signal_B` always short (bar 0 is flat for both).
    fn wfo_base_df(n: usize) -> DataFrame {
        let mut close = vec![100.0; n];
        let mut signal_a = vec![0.0; n];
        let mut signal_b = vec![0.0; n];

        for i in 1..n {
            signal_a[i] = 1.0;
            signal_b[i] = -1.0;
            let delta = if i < n / 2 { 1.0 } else { -1.0 };
            close[i] = close[i - 1] + delta;
        }

        DataFrame::new(vec![
            Column::new("timestamp".into(), (0..n as i64).collect::<Vec<_>>()),
            Column::new("close".into(), close),
            Column::new("signal_A".into(), signal_a),
            Column::new("signal_B".into(), signal_b),
        ])
        .unwrap()
    }

    /// Builds a single-parameter sweep variant.
    fn variant(key: &str, value: f64, signal_col: &str) -> crate::SweepVariant {
        crate::SweepVariant {
            params: std::collections::HashMap::from([(key.into(), value)]),
            signal_col: signal_col.into(),
        }
    }

    #[test]
    fn test_wfo_opt_picks_higher_objective_param_on_train() {
        // Training slice (bars 0..20) is rising, so the long signal wins on total_return.
        let wf = WalkForwardConfig::new(20, 20);
        let df = wfo_base_df(40);
        let variants = vec![
            variant("param", 1.0, "signal_A"),
            variant("param", 2.0, "signal_B"),
        ];

        let out = run_walk_forward_optimize(
            df.lazy(),
            &zero_cost_config(),
            &wf,
            &variants,
            "total_return",
        )
        .unwrap();

        assert_eq!(out.height(), 1);
        let best_param = out.column("best_param").unwrap().f64().unwrap().get(0).unwrap();
        assert_eq!(best_param, 1.0);
    }

    #[test]
    fn test_wfo_opt_oos_uses_locked_param_not_reoptimized() {
        // The long signal wins in-sample but the OOS slice is falling, so a locked-in
        // choice must lose money out of sample.
        let wf = WalkForwardConfig::new(20, 20);
        let df = wfo_base_df(40);
        let variants = vec![
            variant("param", 1.0, "signal_A"),
            variant("param", 2.0, "signal_B"),
        ];
        let out = run_walk_forward_optimize(
            df.lazy(),
            &zero_cost_config(),
            &wf,
            &variants,
            "total_return",
        )
        .unwrap();

        let oos_metric = out.column("oos_metric").unwrap().f64().unwrap().get(0).unwrap();
        assert!(oos_metric < 0.0);
    }

    #[test]
    fn test_wfo_opt_overfit_flag_when_train_oos_diverge() {
        let mut wf = WalkForwardConfig::new(20, 20);
        wf.overfit_threshold = 0.0;
        let df = wfo_base_df(40);
        let variants = vec![variant("p", 1.0, "signal_A")];
        let out = run_walk_forward_optimize(
            df.lazy(),
            &zero_cost_config(),
            &wf,
            &variants,
            "total_return",
        )
        .unwrap();

        let overfit = out.column("overfit_flag").unwrap().bool().unwrap().get(0).unwrap();
        assert!(overfit);
    }

    #[test]
    fn test_wfo_opt_fold_count_matches_walk_forward() {
        // 60 bars, 20 train + 10 test, step 10: fold starts 0, 10, 20, 30 -> 4 folds.
        let wf = WalkForwardConfig::new(20, 10);
        let df = wfo_base_df(60);
        let variants = vec![variant("p", 1.0, "signal_A")];
        let mut cfg = zero_cost_config();
        cfg.signal_col = "signal_A".into();
        let out1 = run_walk_forward(df.clone().lazy(), &cfg, &wf).unwrap();
        let out2 = run_walk_forward_optimize(
            df.lazy(),
            &zero_cost_config(),
            &wf,
            &variants,
            "total_return",
        )
        .unwrap();

        assert_eq!(out1.height(), out2.height());
        assert_eq!(out2.height(), 4);
    }
}