use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::models::chart::Candle;
use super::config::BacktestConfig;
use super::error::{BacktestError, Result};
use super::optimizer::{GridSearch, OptimizationReport, ParamValue};
use super::result::{BacktestResult, PerformanceMetrics};
use super::strategy::Strategy;
/// Outcome of a single walk-forward window: the parameters selected on the
/// in-sample slice and the backtests obtained with them.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WindowResult {
/// Zero-based position of this window in the walk-forward sequence.
pub window: usize,
/// Parameter set reported as best by the grid search on the in-sample slice.
pub optimized_params: HashMap<String, ParamValue>,
/// Backtest result attached to the best in-sample parameter set.
pub in_sample: BacktestResult,
/// Backtest of a strategy built from `optimized_params`, run on the
/// out-of-sample slice that immediately follows the in-sample slice.
pub out_of_sample: BacktestResult,
}
/// Full result of a walk-forward run, as returned by `WalkForwardConfig::run`.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalkForwardReport {
/// Strategy name copied from the first window's in-sample result.
pub strategy_name: String,
/// Per-window results, in the order the windows were evaluated.
pub windows: Vec<WindowResult>,
/// Metrics computed over the out-of-sample results of all windows combined.
pub aggregate_metrics: PerformanceMetrics,
/// Fraction of windows (0.0 to 1.0) whose out-of-sample result reports
/// itself as profitable.
pub consistency_ratio: f64,
/// Grid-search report of each window; entry `i` belongs to `windows[i]`.
pub optimization_reports: Vec<OptimizationReport>,
}
/// Settings for a walk-forward analysis: what to optimise, how to backtest,
/// and how the candle series is cut into in-sample / out-of-sample windows.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct WalkForwardConfig {
/// Grid search run on every in-sample slice to pick parameters.
pub grid: GridSearch,
/// Backtest configuration used for both the optimisation and the
/// out-of-sample runs.
pub config: BacktestConfig,
/// Number of candles in each in-sample (optimisation) slice.
pub in_sample_bars: usize,
/// Number of candles in each out-of-sample (validation) slice.
pub out_of_sample_bars: usize,
/// Number of candles the window start advances between iterations;
/// `None` means "advance by `out_of_sample_bars`".
pub step_bars: Option<usize>,
}
impl WalkForwardConfig {
    /// Creates a configuration with the default window sizes: 252 in-sample
    /// bars, 63 out-of-sample bars and no explicit step (the step then
    /// defaults to the out-of-sample length).
    pub fn new(grid: GridSearch, config: BacktestConfig) -> Self {
        Self {
            grid,
            config,
            in_sample_bars: 252,
            out_of_sample_bars: 63,
            step_bars: None,
        }
    }

    /// Sets the number of candles in each in-sample slice.
    pub fn in_sample_bars(mut self, bars: usize) -> Self {
        self.in_sample_bars = bars;
        self
    }

    /// Sets the number of candles in each out-of-sample slice.
    pub fn out_of_sample_bars(mut self, bars: usize) -> Self {
        self.out_of_sample_bars = bars;
        self
    }

    /// Sets how many candles the window start advances between iterations.
    ///
    /// A step of zero is rejected by [`run`](Self::run), because the window
    /// would never move forward.
    pub fn step_bars(mut self, bars: usize) -> Self {
        self.step_bars = Some(bars);
        self
    }

    /// Runs the walk-forward analysis over `candles`.
    ///
    /// For every window the grid search is run on the in-sample slice, a
    /// strategy is built from the best parameters via `factory`, and that
    /// strategy is backtested on the out-of-sample slice that follows. The
    /// window start then advances by `step_bars` (or `out_of_sample_bars`
    /// when no step was set) until a full window no longer fits.
    ///
    /// # Errors
    ///
    /// Returns an error when
    /// - `in_sample_bars`, `out_of_sample_bars` or an explicit `step_bars`
    ///   is zero,
    /// - `candles` is shorter than one full window,
    /// - the optimisation or the out-of-sample backtest of any window fails
    ///   (reported as an invalid `walk_forward` parameter that embeds the
    ///   window index and the underlying error message).
    pub fn run<S, F>(
        &self,
        symbol: &str,
        candles: &[Candle],
        factory: F,
    ) -> Result<WalkForwardReport>
    where
        S: Strategy + Clone + Send,
        F: Fn(&HashMap<String, ParamValue>) -> S,
        F: Send + Sync,
    {
        self.validate(candles.len())?;

        let step = self.step_bars.unwrap_or(self.out_of_sample_bars);
        // `validate` guarantees this sum fits: it is at most `candles.len()`.
        let total_bars = self.in_sample_bars + self.out_of_sample_bars;

        let mut windows: Vec<WindowResult> = Vec::new();
        let mut opt_reports: Vec<OptimizationReport> = Vec::new();
        let mut start: usize = 0;

        // Checked arithmetic: a very large step must end the loop, not wrap
        // around and produce a bogus window.
        while let Some(oos_end) = start
            .checked_add(total_bars)
            .filter(|&end| end <= candles.len())
        {
            // Exactly one result is pushed per iteration, so the number of
            // collected windows is the index of the current one.
            let window_idx = windows.len();
            let is_end = start + self.in_sample_bars;
            let is_candles = &candles[start..is_end];
            let oos_candles = &candles[is_end..oos_end];

            let opt_report = self
                .grid
                .run(symbol, is_candles, &self.config, &factory)
                .map_err(|e| {
                    BacktestError::invalid_param(
                        "walk_forward",
                        format!("window {window_idx} optimisation failed: {e}"),
                    )
                })?;

            let best_params = opt_report.best.params.clone();
            let is_result = opt_report.best.result.clone();

            let oos_strategy = factory(&best_params);
            let oos_result = crate::backtesting::BacktestEngine::new(self.config.clone())
                .run(symbol, oos_candles, oos_strategy)
                .map_err(|e| {
                    BacktestError::invalid_param(
                        "walk_forward",
                        format!("window {window_idx} OOS run failed: {e}"),
                    )
                })?;

            windows.push(WindowResult {
                window: window_idx,
                optimized_params: best_params,
                in_sample: is_result,
                out_of_sample: oos_result,
            });
            opt_reports.push(opt_report);

            start = match start.checked_add(step) {
                Some(next) => next,
                None => break,
            };
        }

        // Defensive: `validate` already ensures at least one window fits.
        if windows.is_empty() {
            return Err(BacktestError::invalid_param(
                "candles",
                "not enough data for any walk-forward window",
            ));
        }

        let strategy_name = windows[0].in_sample.strategy_name.clone();
        let consistency_ratio = calculate_consistency_ratio(&windows);
        let aggregate_metrics = aggregate_oos_metrics(
            &windows,
            self.config.risk_free_rate,
            self.config.bars_per_year,
        );

        Ok(WalkForwardReport {
            strategy_name,
            windows,
            aggregate_metrics,
            consistency_ratio,
            optimization_reports: opt_reports,
        })
    }

    /// Checks the window settings against the number of available candles.
    ///
    /// Rejects zero-length slices, a zero step (which would never advance
    /// the window) and inputs too short to hold one full window.
    fn validate(&self, num_candles: usize) -> Result<()> {
        if self.in_sample_bars == 0 {
            return Err(BacktestError::invalid_param(
                "in_sample_bars",
                "must be greater than zero",
            ));
        }
        if self.out_of_sample_bars == 0 {
            return Err(BacktestError::invalid_param(
                "out_of_sample_bars",
                "must be greater than zero",
            ));
        }
        if self.step_bars == Some(0) {
            return Err(BacktestError::invalid_param(
                "step_bars",
                "must be greater than zero",
            ));
        }

        // Saturate instead of overflowing: an absurdly large window is simply
        // reported as "not enough data".
        let total_bars = self.in_sample_bars.saturating_add(self.out_of_sample_bars);
        if num_candles < total_bars {
            return Err(BacktestError::insufficient_data(total_bars, num_candles));
        }
        Ok(())
    }
}
/// Returns the share of windows whose out-of-sample backtest was profitable.
///
/// The result lies in `0.0..=1.0`; an empty slice yields `0.0`.
fn calculate_consistency_ratio(windows: &[WindowResult]) -> f64 {
    let total = windows.len();
    if total == 0 {
        return 0.0;
    }

    let mut winners = 0usize;
    for entry in windows {
        if entry.out_of_sample.is_profitable() {
            winners += 1;
        }
    }

    winners as f64 / total as f64
}
/// Combines the out-of-sample results of all windows into one set of
/// performance metrics.
///
/// Trades and signal counts are pooled across windows. The per-window equity
/// curves are chained into a single curve: each window's points are rescaled
/// so that the window starts at the equity the previous window ended on, and
/// the first point of every window after the first is dropped. Drawdown is
/// then recomputed over the chained curve.
///
/// Windows whose initial capital is not positive contribute trades and
/// signals but no equity points. An empty `windows` slice does not panic; the
/// metrics are computed from empty inputs with a starting capital of
/// `10_000.0`.
fn aggregate_oos_metrics(
    windows: &[WindowResult],
    risk_free_rate: f64,
    bars_per_year: f64,
) -> PerformanceMetrics {
    use crate::backtesting::result::EquityPoint;

    let all_trades: Vec<_> = windows
        .iter()
        .flat_map(|w| w.out_of_sample.trades.iter().cloned())
        .collect();

    // Resolved once and reused as the seed of the chained curve, so an empty
    // slice takes the fallback instead of panicking on `windows[0]`.
    let initial_capital = windows
        .first()
        .map(|w| w.out_of_sample.initial_capital)
        .unwrap_or(10_000.0);

    let mut combined_equity: Vec<EquityPoint> = Vec::new();
    let mut running_equity = initial_capital;

    for (window_idx, window) in windows.iter().enumerate() {
        let window_initial = window.out_of_sample.initial_capital;
        if window_initial <= 0.0 {
            // Cannot express this window's curve as a ratio of its start.
            continue;
        }

        // Every window after the first drops its leading equity point.
        let leading_points_to_skip = usize::from(window_idx > 0);
        let base_equity = running_equity;
        combined_equity.extend(
            window
                .out_of_sample
                .equity_curve
                .iter()
                .skip(leading_points_to_skip)
                .map(|point| EquityPoint {
                    timestamp: point.timestamp,
                    equity: base_equity * (point.equity / window_initial),
                    // Filled in by the drawdown pass below.
                    drawdown_pct: 0.0,
                }),
        );

        if let Some(last) = combined_equity.last() {
            running_equity = last.equity;
        }
    }

    // Recompute drawdown against the running peak of the chained curve.
    let mut peak = f64::NEG_INFINITY;
    for point in &mut combined_equity {
        peak = peak.max(point.equity);
        point.drawdown_pct = if peak > 0.0 {
            (peak - point.equity) / peak
        } else {
            0.0
        };
    }

    let total_signals: usize = windows.iter().map(|w| w.out_of_sample.signals.len()).sum();
    let executed_signals: usize = windows
        .iter()
        .map(|w| {
            w.out_of_sample
                .signals
                .iter()
                .filter(|s| s.executed)
                .count()
        })
        .sum();

    PerformanceMetrics::calculate(
        &all_trades,
        &combined_equity,
        initial_capital,
        total_signals,
        executed_signals,
        risk_free_rate,
        bars_per_year,
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::{
        BacktestConfig, SmaCrossover,
        optimizer::{OptimizeMetric, ParamRange},
    };
    use crate::models::chart::Candle;

    /// Builds one candle per price with sequential timestamps starting at 0
    /// and a +/-1% high/low band around the price.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1000,
                adj_close: Some(p),
            })
            .collect()
    }

    /// `n` prices starting at 100.0 and rising by `slope` per bar.
    fn linear_prices(n: usize, slope: f64) -> Vec<f64> {
        (0..n).map(|i| 100.0 + i as f64 * slope).collect()
    }

    /// `n` prices starting at 100.0 and rising by 0.3 per bar.
    fn trending_prices(n: usize) -> Vec<f64> {
        linear_prices(n, 0.3)
    }

    /// Backtest configuration with commission and slippage set to zero.
    fn zero_cost_config() -> BacktestConfig {
        BacktestConfig::builder()
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap()
    }

    /// Strategy factory shared by all tests: reads `fast` and `slow` from the
    /// parameter set and builds an SMA crossover from them.
    fn sma_from_params(params: &HashMap<String, ParamValue>) -> SmaCrossover {
        SmaCrossover::new(
            params["fast"].as_int() as usize,
            params["slow"].as_int() as usize,
        )
    }

    /// Grid with a single candidate (`fast = 3`, `slow = 10`), ranked by
    /// total return.
    fn single_point_grid() -> GridSearch {
        GridSearch::new()
            .param("fast", ParamRange::int_range(3, 3, 1))
            .param("slow", ParamRange::int_range(10, 10, 1))
            .optimize_for(OptimizeMetric::TotalReturn)
    }

    #[test]
    fn test_walk_forward_basic() {
        let candles = make_candles(&trending_prices(300));
        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 9, 3))
            .param("slow", ParamRange::int_range(10, 20, 10))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, zero_cost_config())
            .in_sample_bars(200)
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_from_params)
            .unwrap();

        // 300 candles hold exactly one 200 + 100 window.
        assert_eq!(report.windows.len(), 1);
        assert_eq!(report.strategy_name, "SMA Crossover");
        assert!(report.consistency_ratio >= 0.0);
        assert!(report.consistency_ratio <= 1.0);
    }

    #[test]
    fn test_walk_forward_multiple_windows() {
        let candles = make_candles(&trending_prices(500));
        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 6, 3))
            .param("slow", ParamRange::int_range(10, 10, 1))
            .optimize_for(OptimizeMetric::TotalReturn);

        let report = WalkForwardConfig::new(grid, zero_cost_config())
            .in_sample_bars(200)
            .out_of_sample_bars(100)
            .step_bars(100)
            .run("TEST", &candles, sma_from_params)
            .unwrap();

        assert!(report.windows.len() >= 2);
        assert_eq!(report.optimization_reports.len(), report.windows.len());
    }

    #[test]
    fn test_insufficient_data_errors() {
        // 50 candles cannot hold a 200 + 100 window.
        let candles = make_candles(&trending_prices(50));
        let config = BacktestConfig::default();
        let grid = GridSearch::new()
            .param("fast", ParamRange::int_range(3, 6, 3))
            .param("slow", ParamRange::int_range(10, 10, 1));

        let result = WalkForwardConfig::new(grid, config)
            .in_sample_bars(200)
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_from_params);

        assert!(result.is_err());
    }

    #[test]
    fn test_consistency_ratio_all_profitable() {
        let candles = make_candles(&linear_prices(300, 1.0));

        let report = WalkForwardConfig::new(single_point_grid(), zero_cost_config())
            .in_sample_bars(150)
            .out_of_sample_bars(100)
            .run("TEST", &candles, sma_from_params)
            .unwrap();

        assert!(report.consistency_ratio >= 0.0);
    }

    /// Checks that the equity curve of every individual out-of-sample window
    /// has strictly increasing timestamps.
    #[test]
    fn test_aggregate_equity_timestamps_are_monotonic() {
        let candles = make_candles(&linear_prices(600, 0.5));

        let report = WalkForwardConfig::new(single_point_grid(), zero_cost_config())
            .in_sample_bars(100)
            .out_of_sample_bars(50)
            .run("TEST", &candles, sma_from_params)
            .unwrap();

        assert!(
            report.windows.len() >= 2,
            "Expected multiple windows for timestamp test"
        );

        for window in &report.windows {
            let ts: Vec<i64> = window
                .out_of_sample
                .equity_curve
                .iter()
                .map(|ep| ep.timestamp)
                .collect();
            for pair in ts.windows(2) {
                assert!(
                    pair[0] < pair[1],
                    "Equity curve timestamps not strictly increasing within window: {} >= {}",
                    pair[0],
                    pair[1]
                );
            }
        }
    }

    /// Rebuilds the timestamp sequence the way the aggregation chains the
    /// windows (dropping the first point of every window after the first) and
    /// checks that it is strictly increasing and starts at the first
    /// out-of-sample equity point.
    #[test]
    fn test_aggregate_oos_equity_timestamps_are_gapless_across_windows() {
        let candles = make_candles(&linear_prices(600, 0.5));

        let report = WalkForwardConfig::new(single_point_grid(), zero_cost_config())
            .in_sample_bars(100)
            .out_of_sample_bars(50)
            .run("TEST", &candles, sma_from_params)
            .unwrap();

        assert!(
            report.windows.len() >= 2,
            "Need at least 2 OOS windows for this test"
        );

        let combined_ts: Vec<i64> = report
            .windows
            .iter()
            .enumerate()
            .flat_map(|(wi, w)| {
                w.out_of_sample
                    .equity_curve
                    .iter()
                    .enumerate()
                    .filter(move |&(pi, _)| !(wi > 0 && pi == 0))
                    .map(|(_, ep)| ep.timestamp)
            })
            .collect();

        for pair in combined_ts.windows(2) {
            assert!(
                pair[0] < pair[1],
                "Combined equity curve timestamps not strictly increasing: {} >= {}",
                pair[0],
                pair[1]
            );
        }

        let expected_first = report
            .windows
            .first()
            .and_then(|w| w.out_of_sample.equity_curve.first())
            .map(|ep| ep.timestamp)
            .unwrap_or(0);
        assert_eq!(
            combined_ts.first().copied().unwrap_or(-1),
            expected_first,
            "First combined timestamp should equal the first OOS equity point timestamp"
        );
    }
}