1use std::collections::HashMap;
50
51use serde::{Deserialize, Serialize};
52
53use crate::models::chart::Candle;
54
55use super::config::BacktestConfig;
56use super::error::{BacktestError, Result};
57use super::optimizer::{GridSearch, OptimizationReport, ParamValue};
58use super::result::{BacktestResult, PerformanceMetrics};
59use super::strategy::Strategy;
60
/// Outcome of a single walk-forward window.
///
/// Each window pairs an in-sample optimisation with an out-of-sample
/// validation run that uses the parameters chosen in-sample.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WindowResult {
    /// Zero-based position of this window in the walk-forward sequence.
    pub window: usize,
    /// Best parameter set found by the grid search on the in-sample slice.
    pub optimized_params: HashMap<String, ParamValue>,
    /// Backtest result of the best parameter set on the in-sample slice.
    pub in_sample: BacktestResult,
    /// Backtest result of a strategy built from `optimized_params`, run on
    /// the out-of-sample slice that immediately follows the in-sample slice.
    pub out_of_sample: BacktestResult,
}
76
/// Full report produced by a walk-forward analysis.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalkForwardReport {
    /// Strategy name, taken from the first window's in-sample result.
    pub strategy_name: String,
    /// Per-window results, in the order the windows were evaluated.
    pub windows: Vec<WindowResult>,
    /// Metrics computed over all out-of-sample trades and a single equity
    /// curve stitched together from every window's out-of-sample curve.
    pub aggregate_metrics: PerformanceMetrics,
    /// Fraction of windows whose out-of-sample result is profitable
    /// (between 0.0 and 1.0).
    pub consistency_ratio: f64,
    /// Grid-search report for each window, index-aligned with `windows`.
    pub optimization_reports: Vec<OptimizationReport>,
}
92
/// Configuration for a rolling walk-forward analysis.
///
/// The candle series is cut into windows of `in_sample_bars` bars used for
/// parameter optimisation, each followed by `out_of_sample_bars` bars used
/// to validate the chosen parameters.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct WalkForwardConfig {
    /// Parameter grid searched on every in-sample slice.
    pub grid: GridSearch,
    /// Backtest settings shared by the in-sample optimisation and the
    /// out-of-sample runs.
    pub config: BacktestConfig,
    /// Number of bars in each in-sample (optimisation) slice; 252 when built
    /// with `WalkForwardConfig::new`.
    pub in_sample_bars: usize,
    /// Number of bars in each out-of-sample (validation) slice; 63 when built
    /// with `WalkForwardConfig::new`.
    pub out_of_sample_bars: usize,
    /// How far the window advances between iterations. `None` means the
    /// step equals `out_of_sample_bars`, so out-of-sample slices are
    /// back-to-back.
    pub step_bars: Option<usize>,
}
115
116impl WalkForwardConfig {
117 pub fn new(grid: GridSearch, config: BacktestConfig) -> Self {
121 Self {
122 grid,
123 config,
124 in_sample_bars: 252,
125 out_of_sample_bars: 63,
126 step_bars: None,
127 }
128 }
129
130 pub fn in_sample_bars(mut self, bars: usize) -> Self {
132 self.in_sample_bars = bars;
133 self
134 }
135
136 pub fn out_of_sample_bars(mut self, bars: usize) -> Self {
138 self.out_of_sample_bars = bars;
139 self
140 }
141
142 pub fn step_bars(mut self, bars: usize) -> Self {
146 self.step_bars = Some(bars);
147 self
148 }
149
150 pub fn run<S, F>(
159 &self,
160 symbol: &str,
161 candles: &[Candle],
162 factory: F,
163 ) -> Result<WalkForwardReport>
164 where
165 S: Strategy + Clone + Send,
166 F: Fn(&HashMap<String, ParamValue>) -> S,
167 F: Send + Sync,
168 {
169 self.validate(candles.len())?;
170
171 let step = self.step_bars.unwrap_or(self.out_of_sample_bars);
172 let total_bars = self.in_sample_bars + self.out_of_sample_bars;
173
174 let mut windows: Vec<WindowResult> = Vec::new();
176 let mut opt_reports: Vec<OptimizationReport> = Vec::new();
177 let mut window_idx = 0;
178 let mut start = 0;
179
180 while start + total_bars <= candles.len() {
181 let is_end = start + self.in_sample_bars;
182 let oos_end = is_end + self.out_of_sample_bars;
183
184 let is_candles = &candles[start..is_end];
185 let oos_candles = &candles[is_end..oos_end];
186
187 let opt_report = self
189 .grid
190 .run(symbol, is_candles, &self.config, &factory)
191 .map_err(|e| {
192 BacktestError::invalid_param(
193 "walk_forward",
194 format!("window {window_idx} optimisation failed: {e}"),
195 )
196 })?;
197
198 let best_params = opt_report.best.params.clone();
199 let is_result = opt_report.best.result.clone();
200
201 let oos_strategy = factory(&best_params);
203 let oos_result = crate::backtesting::BacktestEngine::new(self.config.clone())
204 .run(symbol, oos_candles, oos_strategy)
205 .map_err(|e| {
206 BacktestError::invalid_param(
207 "walk_forward",
208 format!("window {window_idx} OOS run failed: {e}"),
209 )
210 })?;
211
212 windows.push(WindowResult {
213 window: window_idx,
214 optimized_params: best_params,
215 in_sample: is_result,
216 out_of_sample: oos_result,
217 });
218 opt_reports.push(opt_report);
219
220 start += step;
221 window_idx += 1;
222 }
223
224 if windows.is_empty() {
225 return Err(BacktestError::invalid_param(
226 "candles",
227 "not enough data for any walk-forward window",
228 ));
229 }
230
231 let strategy_name = windows[0].in_sample.strategy_name.clone();
232 let consistency_ratio = calculate_consistency_ratio(&windows);
233 let aggregate_metrics = aggregate_oos_metrics(
234 &windows,
235 self.config.risk_free_rate,
236 self.config.bars_per_year,
237 );
238
239 Ok(WalkForwardReport {
240 strategy_name,
241 windows,
242 aggregate_metrics,
243 consistency_ratio,
244 optimization_reports: opt_reports,
245 })
246 }
247
248 fn validate(&self, num_candles: usize) -> Result<()> {
250 if self.in_sample_bars == 0 {
251 return Err(BacktestError::invalid_param(
252 "in_sample_bars",
253 "must be greater than zero",
254 ));
255 }
256 if self.out_of_sample_bars == 0 {
257 return Err(BacktestError::invalid_param(
258 "out_of_sample_bars",
259 "must be greater than zero",
260 ));
261 }
262 let total_bars = self.in_sample_bars + self.out_of_sample_bars;
263 if num_candles < total_bars {
264 return Err(BacktestError::insufficient_data(total_bars, num_candles));
265 }
266 Ok(())
267 }
268}
269
270fn calculate_consistency_ratio(windows: &[WindowResult]) -> f64 {
274 if windows.is_empty() {
275 return 0.0;
276 }
277 let profitable = windows
278 .iter()
279 .filter(|w| w.out_of_sample.is_profitable())
280 .count();
281 profitable as f64 / windows.len() as f64
282}
283
284fn aggregate_oos_metrics(
289 windows: &[WindowResult],
290 risk_free_rate: f64,
291 bars_per_year: f64,
292) -> PerformanceMetrics {
293 use crate::backtesting::result::EquityPoint;
294
295 let all_trades: Vec<_> = windows
296 .iter()
297 .flat_map(|w| w.out_of_sample.trades.iter().cloned())
298 .collect();
299
300 let mut combined_equity: Vec<EquityPoint> = Vec::new();
305 let mut running_equity = windows[0].out_of_sample.initial_capital;
307
308 for (window_idx, window) in windows.iter().enumerate() {
309 let window_initial = window.out_of_sample.initial_capital;
310 if window_initial <= 0.0 {
311 continue;
312 }
313
314 for (point_idx, point) in window.out_of_sample.equity_curve.iter().enumerate() {
315 if window_idx > 0 && point_idx == 0 {
316 continue;
317 }
318
319 let scaled_equity = running_equity * (point.equity / window_initial);
320 combined_equity.push(EquityPoint {
321 timestamp: point.timestamp,
322 equity: scaled_equity,
323 drawdown_pct: 0.0,
324 });
325 }
326
327 if let Some(last) = combined_equity.last() {
328 running_equity = last.equity;
329 }
330 }
331
332 let mut peak = f64::NEG_INFINITY;
334 for point in &mut combined_equity {
335 peak = peak.max(point.equity);
336 point.drawdown_pct = if peak > 0.0 {
337 (peak - point.equity) / peak
338 } else {
339 0.0
340 };
341 }
342
343 let initial_capital = windows
345 .first()
346 .map(|w| w.out_of_sample.initial_capital)
347 .unwrap_or(10_000.0);
348
349 let total_signals: usize = windows.iter().map(|w| w.out_of_sample.signals.len()).sum();
350 let executed_signals: usize = windows
351 .iter()
352 .map(|w| {
353 w.out_of_sample
354 .signals
355 .iter()
356 .filter(|s| s.executed)
357 .count()
358 })
359 .sum();
360
361 PerformanceMetrics::calculate(
362 &all_trades,
363 &combined_equity,
364 initial_capital,
365 total_signals,
366 executed_signals,
367 risk_free_rate,
368 bars_per_year,
369 )
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::{
        BacktestConfig, SmaCrossover,
        optimizer::{OptimizeMetric, ParamRange, ParamValue},
    };
    use crate::models::chart::Candle;
    use std::collections::HashMap;

    /// Builds one candle per price, using the bar index as timestamp and a
    /// ±1% high/low band around the price.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1000,
                adj_close: Some(p),
                provider_id: None,
            })
            .collect()
    }

    /// Linearly rising prices: 100.0 plus 0.3 per bar.
    fn trending_prices(n: usize) -> Vec<f64> {
        (0..n).map(|i| 100.0 + i as f64 * 0.3).collect()
    }

    /// Backtest configuration with commission and slippage disabled.
    fn frictionless_config() -> BacktestConfig {
        BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap()
    }

    /// Single-point grid (fast = 3, slow = 10) optimised for total return.
    fn fixed_grid() -> GridSearch {
        GridSearch::new()
            .param("fast", ParamRange::int_range(3, 3, 1))
            .param("slow", ParamRange::int_range(10, 10, 1))
            .optimize_for(OptimizeMetric::TotalReturn)
    }

    /// Strategy factory shared by the tests: SMA crossover built from the
    /// `fast` and `slow` grid parameters.
    fn sma_factory(params: &HashMap<String, ParamValue>) -> SmaCrossover {
        SmaCrossover::new(
            params["fast"].as_int() as usize,
            params["slow"].as_int() as usize,
        )
    }

    #[test]
    fn test_walk_forward_basic() {
        let candles = make_candles(&trending_prices(300));

        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 9, 3))
            .param("slow", ParamRange::int_range(10, 20, 10))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, frictionless_config())
            .in_sample_bars(200)
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_factory)
            .unwrap();

        assert_eq!(report.windows.len(), 1);
        assert_eq!(report.strategy_name, "SMA Crossover");
        assert!(report.consistency_ratio >= 0.0);
        assert!(report.consistency_ratio <= 1.0);
    }

    #[test]
    fn test_walk_forward_multiple_windows() {
        let candles = make_candles(&trending_prices(500));

        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 6, 3))
            .param("slow", ParamRange::int_range(10, 10, 1))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, frictionless_config())
            .in_sample_bars(200)
            .out_of_sample_bars(100)
            .step_bars(100)
            .run("TEST", &candles, sma_factory)
            .unwrap();

        // Windows start at bars 0, 100 and 200; a start at 300 would need 600 bars.
        assert_eq!(report.windows.len(), 3);
        assert_eq!(report.optimization_reports.len(), report.windows.len());
        for (expected_idx, window) in report.windows.iter().enumerate() {
            assert_eq!(window.window, expected_idx);
        }
    }

    #[test]
    fn test_insufficient_data_errors() {
        let candles = make_candles(&trending_prices(50));
        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 6, 3))
            .param("slow", ParamRange::int_range(10, 10, 1));

        let result = WalkForwardConfig::new(grid, BacktestConfig::default())
            .in_sample_bars(200)
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_factory);

        assert!(result.is_err());
    }

    #[test]
    fn test_zero_in_sample_bars_errors() {
        let candles = make_candles(&trending_prices(300));
        let result = WalkForwardConfig::new(fixed_grid(), frictionless_config())
            .in_sample_bars(0)
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_factory);

        assert!(result.is_err());
    }

    #[test]
    fn test_zero_out_of_sample_bars_errors() {
        let candles = make_candles(&trending_prices(300));
        let result = WalkForwardConfig::new(fixed_grid(), frictionless_config())
            .in_sample_bars(100)
            .out_of_sample_bars(0)
            .run("TEST", &candles, sma_factory);

        assert!(result.is_err());
    }

    #[test]
    fn test_consistency_ratio_all_profitable() {
        let prices: Vec<f64> = (0..300).map(|i| 100.0 + i as f64).collect();
        let candles = make_candles(&prices);

        let report = WalkForwardConfig::new(fixed_grid(), frictionless_config())
            .in_sample_bars(150)
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_factory)
            .unwrap();

        assert!(report.consistency_ratio >= 0.0);
        assert!(report.consistency_ratio <= 1.0);
    }

    #[test]
    fn test_aggregate_equity_timestamps_are_monotonic() {
        let prices: Vec<f64> = (0..600).map(|i| 100.0 + (i as f64) * 0.5).collect();
        let candles = make_candles(&prices);

        let report = WalkForwardConfig::new(fixed_grid(), frictionless_config())
            .in_sample_bars(100)
            .out_of_sample_bars(50)
            .run("TEST", &candles, sma_factory)
            .unwrap();

        assert!(
            report.windows.len() >= 2,
            "Expected multiple windows for timestamp test"
        );

        for window in &report.windows {
            let ts: Vec<i64> = window
                .out_of_sample
                .equity_curve
                .iter()
                .map(|ep| ep.timestamp)
                .collect();
            for pair in ts.windows(2) {
                assert!(
                    pair[0] < pair[1],
                    "Equity curve timestamps not strictly increasing within window: {} >= {}",
                    pair[0],
                    pair[1]
                );
            }
        }
    }

    #[test]
    fn test_aggregate_oos_equity_timestamps_are_gapless_across_windows() {
        let prices: Vec<f64> = (0..600).map(|i| 100.0 + (i as f64) * 0.5).collect();
        let candles = make_candles(&prices);

        let report = WalkForwardConfig::new(fixed_grid(), frictionless_config())
            .in_sample_bars(100)
            .out_of_sample_bars(50)
            .run("TEST", &candles, sma_factory)
            .unwrap();

        assert!(
            report.windows.len() >= 2,
            "Need at least 2 OOS windows for this test"
        );

        // Mirror the stitching rule: drop the first point of every window after the first.
        let combined_ts: Vec<i64> = report
            .windows
            .iter()
            .enumerate()
            .flat_map(|(wi, w)| {
                w.out_of_sample
                    .equity_curve
                    .iter()
                    .enumerate()
                    .filter(move |&(pi, _)| !(wi > 0 && pi == 0))
                    .map(|(_, ep)| ep.timestamp)
            })
            .collect();

        for pair in combined_ts.windows(2) {
            assert!(
                pair[0] < pair[1],
                "Combined equity curve timestamps not strictly increasing: {} >= {}",
                pair[0],
                pair[1]
            );
        }

        let expected_first = report
            .windows
            .first()
            .and_then(|w| w.out_of_sample.equity_curve.first())
            .map(|ep| ep.timestamp)
            .unwrap_or(0);
        assert_eq!(
            combined_ts.first().copied().unwrap_or(-1),
            expected_first,
            "First combined timestamp should equal the first OOS equity point timestamp"
        );
    }
}