Skip to main content

quantwave_backtest/
sweep.rs

1//! Parameter sweep helper (quantwave-cr6v.12).
2//!
3//! Runs one backtest per variant and returns a Polars DataFrame with param columns
4//! plus [`PerformanceMetrics`] columns (vectorbt / RaptorBT `batch_spread` pattern).
5
6use crate::{BacktestConfig, BacktestEngine, BacktestError, PerformanceMetrics};
7use polars::prelude::*;
8use std::collections::HashMap;
9
/// A single point of a parameter sweep.
///
/// Pairs the parameter values that identify the grid point with the name of the
/// signal column that [`run_param_sweep`] backtests for it.
#[derive(Clone, Debug, PartialEq)]
pub struct SweepVariant {
    /// Parameter name → value for this grid point; these become the param
    /// columns of the sweep result.
    pub params: HashMap<String, f64>,
    /// Name of the signal column this variant is backtested against.
    pub signal_col: String,
}
16
17/// Build variants for a single-parameter grid (e.g. `hurst_period` → signal columns).
18pub fn single_param_variants(
19    param_name: impl Into<String>,
20    param_values: &[f64],
21    signal_cols: &[impl AsRef<str>],
22) -> Result<Vec<SweepVariant>, BacktestError> {
23    if param_values.len() != signal_cols.len() {
24        return Err(BacktestError::InvalidInput(format!(
25            "param_values len {} != signal_cols len {}",
26            param_values.len(),
27            signal_cols.len()
28        )));
29    }
30    if param_values.is_empty() {
31        return Err(BacktestError::InvalidInput(
32            "sweep requires at least one variant".into(),
33        ));
34    }
35
36    let name = param_name.into();
37    Ok(param_values
38        .iter()
39        .zip(signal_cols.iter())
40        .map(|(&value, col)| SweepVariant {
41            params: HashMap::from([(name.clone(), value)]),
42            signal_col: col.as_ref().to_string(),
43        })
44        .collect())
45}
46
47/// Run backtests for each variant and return a param × metrics DataFrame.
48pub fn run_param_sweep(
49    lf: LazyFrame,
50    variants: &[SweepVariant],
51    base_config: &BacktestConfig,
52) -> Result<DataFrame, BacktestError> {
53    if variants.is_empty() {
54        return Err(BacktestError::InvalidInput(
55            "sweep requires at least one variant".into(),
56        ));
57    }
58
59    let param_keys = sorted_param_keys(variants);
60    let mut param_cols: HashMap<String, Vec<f64>> =
61        param_keys.iter().map(|k| (k.clone(), Vec::new())).collect();
62    let mut metric_cols: HashMap<&'static str, Vec<f64>> = PerformanceMetrics::column_names()
63        .iter()
64        .map(|&name| (name, Vec::new()))
65        .collect();
66
67    for variant in variants {
68        for key in &param_keys {
69            let value = variant.params.get(key).copied().ok_or_else(|| {
70                BacktestError::InvalidInput(format!(
71                    "variant missing param key '{key}' (expected keys: {param_keys:?})"
72                ))
73            })?;
74            param_cols.get_mut(key).unwrap().push(value);
75        }
76
77        let mut config = base_config.clone();
78        config.signal_col = variant.signal_col.clone();
79        let report = BacktestEngine::new(config).backtest_with_report(lf.clone())?;
80        for (name, value) in report.metrics.row_iter() {
81            metric_cols.get_mut(name).unwrap().push(value);
82        }
83    }
84
85    let mut columns: Vec<Column> = Vec::new();
86    for key in &param_keys {
87        columns.push(Column::new(
88            PlSmallStr::from_str(key),
89            param_cols.remove(key).unwrap(),
90        ));
91    }
92    for name in PerformanceMetrics::column_names() {
93        columns.push(Column::new(
94            PlSmallStr::from_str(name),
95            metric_cols.remove(name).unwrap(),
96        ));
97    }
98
99    DataFrame::new(columns).map_err(BacktestError::from)
100}
101
102pub(crate) fn sorted_param_keys(variants: &[SweepVariant]) -> Vec<String> {
103    let mut keys: Vec<String> = variants[0].params.keys().cloned().collect();
104    keys.sort();
105    keys
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Six-row fixture: timestamps spaced 3600 apart, a close series, and three
    /// signal columns (`signal_early` is 1.0 on rows 1-3, `signal_late` on rows
    /// 2-3, `signal_flat` is all zeros).
    fn sweep_base_df() -> DataFrame {
        let timestamps: Vec<i64> = (0..6i64)
            .map(|bar| 1_700_000_000 + bar * 3600)
            .collect();
        let close: Vec<f64> = vec![100.0, 101.0, 102.5, 103.0, 102.0, 101.0];

        let columns = vec![
            Column::new("timestamp".into(), timestamps),
            Column::new("close".into(), close),
            Column::new("signal_early".into(), vec![0.0, 1.0, 1.0, 1.0, 0.0, 0.0]),
            Column::new("signal_late".into(), vec![0.0, 0.0, 1.0, 1.0, 0.0, 0.0]),
            Column::new("signal_flat".into(), vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
        ];
        DataFrame::new(columns).unwrap()
    }

    /// Default config with commission and slippage zeroed and 100k initial cash.
    fn zero_cost_config() -> BacktestConfig {
        let cost_model = crate::CostModel {
            commission_bps: 0.0,
            slippage_bps: 0.0,
            initial_cash: 100_000.0,
        };
        BacktestConfig {
            cost_model,
            ..Default::default()
        }
    }

    #[test]
    fn test_sweep_single_param_returns_metrics_df() {
        let grid: [f64; 3] = [0.5, 1.0, 2.0];
        let signal_cols = ["signal_early", "signal_late", "signal_flat"];
        let variants = single_param_variants("threshold", &grid, &signal_cols).unwrap();

        let df = run_param_sweep(sweep_base_df().lazy(), &variants, &zero_cost_config()).unwrap();

        assert_eq!(df.height(), 3);
        let expected_columns = ["threshold", "num_trades", "final_equity", "total_return"];
        for name in expected_columns.iter().copied() {
            assert!(df.column(name).is_ok());
        }

        // The param column echoes the grid values in input order.
        let thresholds = df.column("threshold").unwrap().f64().unwrap();
        for (row, &expected) in grid.iter().enumerate() {
            assert_relative_eq!(thresholds.get(row).unwrap(), expected, epsilon = 1e-9);
        }

        // One trade for each of the two active signals, none for the flat one.
        let expected_trades: [f64; 3] = [1.0, 1.0, 0.0];
        let trades = df.column("num_trades").unwrap().f64().unwrap();
        for (row, &expected) in expected_trades.iter().enumerate() {
            assert_relative_eq!(trades.get(row).unwrap(), expected, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_sweep_variants_produce_different_final_equity() {
        let signal_cols = ["signal_early", "signal_late"];
        let variants = single_param_variants("entry_bar", &[1.0, 2.0], &signal_cols).unwrap();

        let df = run_param_sweep(sweep_base_df().lazy(), &variants, &zero_cost_config()).unwrap();
        assert_eq!(df.height(), 2);

        let equity = df.column("final_equity").unwrap().f64().unwrap();
        let e0 = equity.get(0).unwrap();
        let e1 = equity.get(1).unwrap();
        let gap = (e0 - e1).abs();
        assert!(gap > 1.0, "early vs late entry should differ: {e0} vs {e1}");
    }

    #[test]
    fn test_sweep_multi_param_explicit_variants() {
        // Local builder so each grid point reads as (stop_pct, mode, signal column).
        let variant = |stop_pct: f64, mode: f64, signal_col: &str| SweepVariant {
            params: HashMap::from([
                ("stop_pct".to_string(), stop_pct),
                ("mode".to_string(), mode),
            ]),
            signal_col: signal_col.into(),
        };
        let variants = vec![
            variant(0.05, 1.0, "signal_early"),
            variant(0.10, 1.0, "signal_late"),
            variant(0.05, 2.0, "signal_flat"),
        ];

        let df = run_param_sweep(sweep_base_df().lazy(), &variants, &zero_cost_config()).unwrap();
        assert_eq!(df.height(), 3);
        assert!(df.column("mode").is_ok());
        assert!(df.column("stop_pct").is_ok());

        let mode = df.column("mode").unwrap().f64().unwrap();
        assert_eq!(mode.len(), 3);
    }
}
218}