1use std::collections::HashMap;
50
51use serde::{Deserialize, Serialize};
52
53use crate::models::chart::Candle;
54
55use super::config::BacktestConfig;
56use super::error::{BacktestError, Result};
57use super::optimizer::{GridSearch, OptimizationReport, ParamValue};
58use super::result::{BacktestResult, PerformanceMetrics};
59use super::strategy::Strategy;
60
/// Outcome of a single walk-forward window: the parameter set chosen on the
/// in-sample slice, plus the backtests of that set on both slices.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WindowResult {
    /// Zero-based index of this window, in the order windows were evaluated.
    pub window: usize,
    /// Best parameter set reported by the grid search on the in-sample slice.
    pub optimized_params: HashMap<String, ParamValue>,
    /// Backtest result taken from the grid search's best entry on the
    /// in-sample slice.
    pub in_sample: BacktestResult,
    /// Backtest of `optimized_params` on the out-of-sample slice that
    /// immediately follows the in-sample slice.
    pub out_of_sample: BacktestResult,
}
76
/// Result of a complete walk-forward run across all windows.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalkForwardReport {
    /// Strategy name, taken from the first window's in-sample result.
    pub strategy_name: String,
    /// Per-window results, in evaluation order.
    pub windows: Vec<WindowResult>,
    /// Metrics computed over the out-of-sample segments only, with their
    /// equity curves chained into one combined curve.
    pub aggregate_metrics: PerformanceMetrics,
    /// Fraction (0.0 to 1.0) of windows whose out-of-sample result reports
    /// `is_profitable()`.
    pub consistency_ratio: f64,
    /// Full grid-search report for each window, index-aligned with `windows`.
    pub optimization_reports: Vec<OptimizationReport>,
}
92
/// Settings for a rolling walk-forward analysis.
///
/// Build with [`WalkForwardConfig::new`] and the builder-style setters, then
/// call [`WalkForwardConfig::run`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct WalkForwardConfig {
    /// Parameter grid searched on every in-sample slice.
    pub grid: GridSearch,
    /// Backtest settings used both for the in-sample optimisation and for the
    /// out-of-sample run.
    pub config: BacktestConfig,
    /// Length of each in-sample slice in bars (`new` sets 252).
    pub in_sample_bars: usize,
    /// Length of each out-of-sample slice in bars (`new` sets 63).
    pub out_of_sample_bars: usize,
    /// Bars to advance between consecutive windows; `None` means
    /// "advance by `out_of_sample_bars`".
    pub step_bars: Option<usize>,
}
115
116impl WalkForwardConfig {
117 pub fn new(grid: GridSearch, config: BacktestConfig) -> Self {
121 Self {
122 grid,
123 config,
124 in_sample_bars: 252,
125 out_of_sample_bars: 63,
126 step_bars: None,
127 }
128 }
129
130 pub fn in_sample_bars(mut self, bars: usize) -> Self {
132 self.in_sample_bars = bars;
133 self
134 }
135
136 pub fn out_of_sample_bars(mut self, bars: usize) -> Self {
138 self.out_of_sample_bars = bars;
139 self
140 }
141
142 pub fn step_bars(mut self, bars: usize) -> Self {
146 self.step_bars = Some(bars);
147 self
148 }
149
150 pub fn run<S, F>(
159 &self,
160 symbol: &str,
161 candles: &[Candle],
162 factory: F,
163 ) -> Result<WalkForwardReport>
164 where
165 S: Strategy + Clone + Send,
166 F: Fn(&HashMap<String, ParamValue>) -> S,
167 F: Send + Sync,
168 {
169 self.validate(candles.len())?;
170
171 let step = self.step_bars.unwrap_or(self.out_of_sample_bars);
172 let total_bars = self.in_sample_bars + self.out_of_sample_bars;
173
174 let mut windows: Vec<WindowResult> = Vec::new();
176 let mut opt_reports: Vec<OptimizationReport> = Vec::new();
177 let mut window_idx = 0;
178 let mut start = 0;
179
180 while start + total_bars <= candles.len() {
181 let is_end = start + self.in_sample_bars;
182 let oos_end = is_end + self.out_of_sample_bars;
183
184 let is_candles = &candles[start..is_end];
185 let oos_candles = &candles[is_end..oos_end];
186
187 let opt_report = self
189 .grid
190 .run(symbol, is_candles, &self.config, &factory)
191 .map_err(|e| {
192 BacktestError::invalid_param(
193 "walk_forward",
194 format!("window {window_idx} optimisation failed: {e}"),
195 )
196 })?;
197
198 let best_params = opt_report.best.params.clone();
199 let is_result = opt_report.best.result.clone();
200
201 let oos_strategy = factory(&best_params);
203 let oos_result = crate::backtesting::BacktestEngine::new(self.config.clone())
204 .run(symbol, oos_candles, oos_strategy)
205 .map_err(|e| {
206 BacktestError::invalid_param(
207 "walk_forward",
208 format!("window {window_idx} OOS run failed: {e}"),
209 )
210 })?;
211
212 windows.push(WindowResult {
213 window: window_idx,
214 optimized_params: best_params,
215 in_sample: is_result,
216 out_of_sample: oos_result,
217 });
218 opt_reports.push(opt_report);
219
220 start += step;
221 window_idx += 1;
222 }
223
224 if windows.is_empty() {
225 return Err(BacktestError::invalid_param(
226 "candles",
227 "not enough data for any walk-forward window",
228 ));
229 }
230
231 let strategy_name = windows[0].in_sample.strategy_name.clone();
232 let consistency_ratio = calculate_consistency_ratio(&windows);
233 let aggregate_metrics = aggregate_oos_metrics(
234 &windows,
235 self.config.risk_free_rate,
236 self.config.bars_per_year,
237 );
238
239 Ok(WalkForwardReport {
240 strategy_name,
241 windows,
242 aggregate_metrics,
243 consistency_ratio,
244 optimization_reports: opt_reports,
245 })
246 }
247
248 fn validate(&self, num_candles: usize) -> Result<()> {
250 if self.in_sample_bars == 0 {
251 return Err(BacktestError::invalid_param(
252 "in_sample_bars",
253 "must be greater than zero",
254 ));
255 }
256 if self.out_of_sample_bars == 0 {
257 return Err(BacktestError::invalid_param(
258 "out_of_sample_bars",
259 "must be greater than zero",
260 ));
261 }
262 let total_bars = self.in_sample_bars + self.out_of_sample_bars;
263 if num_candles < total_bars {
264 return Err(BacktestError::insufficient_data(total_bars, num_candles));
265 }
266 Ok(())
267 }
268}
269
270fn calculate_consistency_ratio(windows: &[WindowResult]) -> f64 {
274 if windows.is_empty() {
275 return 0.0;
276 }
277 let profitable = windows
278 .iter()
279 .filter(|w| w.out_of_sample.is_profitable())
280 .count();
281 profitable as f64 / windows.len() as f64
282}
283
284fn aggregate_oos_metrics(
289 windows: &[WindowResult],
290 risk_free_rate: f64,
291 bars_per_year: f64,
292) -> PerformanceMetrics {
293 use crate::backtesting::result::EquityPoint;
294
295 let all_trades: Vec<_> = windows
296 .iter()
297 .flat_map(|w| w.out_of_sample.trades.iter().cloned())
298 .collect();
299
300 let mut combined_equity: Vec<EquityPoint> = Vec::new();
305 let mut running_equity = windows[0].out_of_sample.initial_capital;
307
308 for (window_idx, window) in windows.iter().enumerate() {
309 let window_initial = window.out_of_sample.initial_capital;
310 if window_initial <= 0.0 {
311 continue;
312 }
313
314 for (point_idx, point) in window.out_of_sample.equity_curve.iter().enumerate() {
315 if window_idx > 0 && point_idx == 0 {
316 continue;
317 }
318
319 let scaled_equity = running_equity * (point.equity / window_initial);
320 combined_equity.push(EquityPoint {
321 timestamp: point.timestamp,
322 equity: scaled_equity,
323 drawdown_pct: 0.0,
324 });
325 }
326
327 if let Some(last) = combined_equity.last() {
328 running_equity = last.equity;
329 }
330 }
331
332 let mut peak = f64::NEG_INFINITY;
334 for point in &mut combined_equity {
335 peak = peak.max(point.equity);
336 point.drawdown_pct = if peak > 0.0 {
337 (peak - point.equity) / peak
338 } else {
339 0.0
340 };
341 }
342
343 let initial_capital = windows
345 .first()
346 .map(|w| w.out_of_sample.initial_capital)
347 .unwrap_or(10_000.0);
348
349 let total_signals: usize = windows.iter().map(|w| w.out_of_sample.signals.len()).sum();
350 let executed_signals: usize = windows
351 .iter()
352 .map(|w| {
353 w.out_of_sample
354 .signals
355 .iter()
356 .filter(|s| s.executed)
357 .count()
358 })
359 .sum();
360
361 PerformanceMetrics::calculate(
362 &all_trades,
363 &combined_equity,
364 initial_capital,
365 total_signals,
366 executed_signals,
367 risk_free_rate,
368 bars_per_year,
369 )
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::{
        BacktestConfig, SmaCrossover,
        optimizer::{OptimizeMetric, ParamRange},
    };
    use crate::models::chart::Candle;

    /// Builds one candle per price, timestamped by index, with a ±1% high/low band.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1000,
                adj_close: Some(p),
            })
            .collect()
    }

    /// Strictly increasing price series: 100.0, 100.3, 100.6, ...
    fn trending_prices(n: usize) -> Vec<f64> {
        (0..n).map(|i| 100.0 + i as f64 * 0.3).collect()
    }

    /// Backtest configuration with commission and slippage disabled.
    fn frictionless_config() -> BacktestConfig {
        BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap()
    }

    /// Strategy factory shared by every test; reads the `fast`/`slow` grid params.
    fn sma_factory(params: &HashMap<String, ParamValue>) -> SmaCrossover {
        SmaCrossover::new(
            params["fast"].as_int() as usize,
            params["slow"].as_int() as usize,
        )
    }

    #[test]
    fn test_walk_forward_basic() {
        let candles = make_candles(&trending_prices(300));
        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 9, 3))
            .param("slow", ParamRange::int_range(10, 20, 10))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, frictionless_config())
            .in_sample_bars(200)
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_factory)
            .unwrap();

        // 300 bars hold exactly one 200 + 100 window.
        assert_eq!(report.windows.len(), 1);
        assert_eq!(report.strategy_name, "SMA Crossover");
        assert!((0.0..=1.0).contains(&report.consistency_ratio));
    }

    #[test]
    fn test_walk_forward_multiple_windows() {
        let candles = make_candles(&trending_prices(500));
        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 6, 3))
            .param("slow", ParamRange::int_range(10, 10, 1))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, frictionless_config())
            .in_sample_bars(200)
            .out_of_sample_bars(100)
            .step_bars(100)
            .run("TEST", &candles, sma_factory)
            .unwrap();

        // 300-bar windows starting at 0, 100 and 200 fit in 500 bars.
        assert_eq!(report.windows.len(), 3);
        assert_eq!(report.optimization_reports.len(), report.windows.len());
    }

    #[test]
    fn test_insufficient_data_errors() {
        let candles = make_candles(&trending_prices(50));
        let config = BacktestConfig::default();
        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 6, 3))
            .param("slow", ParamRange::int_range(10, 10, 1));

        let result = WalkForwardConfig::new(grid, config)
            .in_sample_bars(200)
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_factory);

        assert!(result.is_err());
    }

    #[test]
    fn test_consistency_ratio_all_profitable() {
        let prices: Vec<f64> = (0..300).map(|i| 100.0 + i as f64).collect();
        let candles = make_candles(&prices);
        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 3, 1))
            .param("slow", ParamRange::int_range(10, 10, 1))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, frictionless_config())
            .in_sample_bars(150)
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_factory)
            .unwrap();

        // A single window, so the ratio is exactly 0.0 or 1.0 depending on it.
        assert_eq!(report.windows.len(), 1);
        let expected = if report.windows[0].out_of_sample.is_profitable() {
            1.0
        } else {
            0.0
        };
        assert_eq!(report.consistency_ratio, expected);
    }

    #[test]
    fn test_consistency_ratio_of_no_windows_is_zero() {
        assert_eq!(calculate_consistency_ratio(&[]), 0.0);
    }

    #[test]
    fn test_aggregate_equity_timestamps_are_monotonic() {
        let prices: Vec<f64> = (0..600).map(|i| 100.0 + (i as f64) * 0.5).collect();
        let candles = make_candles(&prices);
        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 3, 1))
            .param("slow", ParamRange::int_range(10, 10, 1))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, frictionless_config())
            .in_sample_bars(100)
            .out_of_sample_bars(50)
            .run("TEST", &candles, sma_factory)
            .unwrap();

        assert!(
            report.windows.len() >= 2,
            "Expected multiple windows for timestamp test"
        );

        for window in &report.windows {
            let ts: Vec<i64> = window
                .out_of_sample
                .equity_curve
                .iter()
                .map(|ep| ep.timestamp)
                .collect();
            for pair in ts.windows(2) {
                assert!(
                    pair[0] < pair[1],
                    "Equity curve timestamps not strictly increasing within window: {} >= {}",
                    pair[0],
                    pair[1]
                );
            }
        }
    }

    #[test]
    fn test_aggregate_oos_equity_timestamps_are_gapless_across_windows() {
        let prices: Vec<f64> = (0..600).map(|i| 100.0 + (i as f64) * 0.5).collect();
        let candles = make_candles(&prices);
        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 3, 1))
            .param("slow", ParamRange::int_range(10, 10, 1))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, frictionless_config())
            .in_sample_bars(100)
            .out_of_sample_bars(50)
            .run("TEST", &candles, sma_factory)
            .unwrap();

        assert!(
            report.windows.len() >= 2,
            "Need at least 2 OOS windows for this test"
        );

        // Mirror the seam rule used by `aggregate_oos_metrics`: every window
        // after the first drops its opening point.
        let combined_ts: Vec<i64> = report
            .windows
            .iter()
            .enumerate()
            .flat_map(|(wi, w)| {
                w.out_of_sample
                    .equity_curve
                    .iter()
                    .enumerate()
                    .filter(move |&(pi, _)| !(wi > 0 && pi == 0))
                    .map(|(_, ep)| ep.timestamp)
            })
            .collect();

        for pair in combined_ts.windows(2) {
            assert!(
                pair[0] < pair[1],
                "Combined equity curve timestamps not strictly increasing: {} >= {}",
                pair[0],
                pair[1]
            );
        }

        let expected_first = report
            .windows
            .first()
            .and_then(|w| w.out_of_sample.equity_curve.first())
            .map(|ep| ep.timestamp)
            .unwrap_or(0);
        assert_eq!(
            combined_ts.first().copied().unwrap_or(-1),
            expected_first,
            "First combined timestamp should equal the first OOS equity point timestamp"
        );
    }
}