1use crate::{BacktestConfig, BacktestEngine, BacktestError, PerformanceMetrics};
7use polars::prelude::*;
8use std::collections::HashMap;
9
/// One point in a parameter sweep.
///
/// Couples the parameter values that label the variant with the name of the
/// signal column that should be backtested for it.
#[derive(Clone, Debug, PartialEq)]
pub struct SweepVariant {
    /// Parameter name → value pairs identifying this variant.
    pub params: HashMap<String, f64>,
    /// Name of the signal column used when this variant is backtested.
    pub signal_col: String,
}
16
17pub fn single_param_variants(
19 param_name: impl Into<String>,
20 param_values: &[f64],
21 signal_cols: &[impl AsRef<str>],
22) -> Result<Vec<SweepVariant>, BacktestError> {
23 if param_values.len() != signal_cols.len() {
24 return Err(BacktestError::InvalidInput(format!(
25 "param_values len {} != signal_cols len {}",
26 param_values.len(),
27 signal_cols.len()
28 )));
29 }
30 if param_values.is_empty() {
31 return Err(BacktestError::InvalidInput(
32 "sweep requires at least one variant".into(),
33 ));
34 }
35
36 let name = param_name.into();
37 Ok(param_values
38 .iter()
39 .zip(signal_cols.iter())
40 .map(|(&value, col)| SweepVariant {
41 params: HashMap::from([(name.clone(), value)]),
42 signal_col: col.as_ref().to_string(),
43 })
44 .collect())
45}
46
47pub fn run_param_sweep(
49 lf: LazyFrame,
50 variants: &[SweepVariant],
51 base_config: &BacktestConfig,
52) -> Result<DataFrame, BacktestError> {
53 if variants.is_empty() {
54 return Err(BacktestError::InvalidInput(
55 "sweep requires at least one variant".into(),
56 ));
57 }
58
59 let param_keys = sorted_param_keys(variants);
60 let mut param_cols: HashMap<String, Vec<f64>> =
61 param_keys.iter().map(|k| (k.clone(), Vec::new())).collect();
62 let mut metric_cols: HashMap<&'static str, Vec<f64>> = PerformanceMetrics::column_names()
63 .iter()
64 .map(|&name| (name, Vec::new()))
65 .collect();
66
67 for variant in variants {
68 for key in ¶m_keys {
69 let value = variant.params.get(key).copied().ok_or_else(|| {
70 BacktestError::InvalidInput(format!(
71 "variant missing param key '{key}' (expected keys: {param_keys:?})"
72 ))
73 })?;
74 param_cols.get_mut(key).unwrap().push(value);
75 }
76
77 let mut config = base_config.clone();
78 config.signal_col = variant.signal_col.clone();
79 let report = BacktestEngine::new(config).backtest_with_report(lf.clone())?;
80 for (name, value) in report.metrics.row_iter() {
81 metric_cols.get_mut(name).unwrap().push(value);
82 }
83 }
84
85 let mut columns: Vec<Column> = Vec::new();
86 for key in ¶m_keys {
87 columns.push(Column::new(
88 PlSmallStr::from_str(key),
89 param_cols.remove(key).unwrap(),
90 ));
91 }
92 for name in PerformanceMetrics::column_names() {
93 columns.push(Column::new(
94 PlSmallStr::from_str(name),
95 metric_cols.remove(name).unwrap(),
96 ));
97 }
98
99 DataFrame::new(columns).map_err(BacktestError::from)
100}
101
102pub(crate) fn sorted_param_keys(variants: &[SweepVariant]) -> Vec<String> {
103 let mut keys: Vec<String> = variants[0].params.keys().cloned().collect();
104 keys.sort();
105 keys
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Six hourly bars with three alternative signal columns
    /// (`signal_early`, `signal_late`, `signal_flat`).
    fn sweep_base_df() -> DataFrame {
        DataFrame::new(vec![
            Column::new(
                "timestamp".into(),
                (0..6i64)
                    .map(|i| 1_700_000_000 + i * 3600)
                    .collect::<Vec<_>>(),
            ),
            Column::new(
                "close".into(),
                vec![100.0, 101.0, 102.5, 103.0, 102.0, 101.0],
            ),
            Column::new("signal_early".into(), vec![0.0, 1.0, 1.0, 1.0, 0.0, 0.0]),
            Column::new("signal_late".into(), vec![0.0, 0.0, 1.0, 1.0, 0.0, 0.0]),
            Column::new("signal_flat".into(), vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
        ])
        .unwrap()
    }

    /// Default config with commission and slippage set to zero.
    fn zero_cost_config() -> BacktestConfig {
        BacktestConfig {
            cost_model: crate::CostModel {
                commission_bps: 0.0,
                slippage_bps: 0.0,
                initial_cash: 100_000.0,
            },
            ..Default::default()
        }
    }

    /// Builds a variant from `(name, value)` pairs and a signal column name.
    fn variant(params: &[(&str, f64)], signal_col: &str) -> SweepVariant {
        SweepVariant {
            params: params
                .iter()
                .map(|&(name, value)| (name.to_string(), value))
                .collect(),
            signal_col: signal_col.to_string(),
        }
    }

    /// Asserts that the f64 column `name` holds exactly `expected`, row by row.
    fn assert_f64_column(df: &DataFrame, name: &str, expected: &[f64]) {
        let actual = df.column(name).unwrap().f64().unwrap();
        assert_eq!(actual.len(), expected.len(), "length of column '{name}'");
        for (row, &want) in expected.iter().enumerate() {
            assert_relative_eq!(actual.get(row).unwrap(), want, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_sweep_single_param_returns_metrics_df() {
        let variants = single_param_variants(
            "threshold",
            &[0.5, 1.0, 2.0],
            &["signal_early", "signal_late", "signal_flat"],
        )
        .unwrap();

        let df = run_param_sweep(sweep_base_df().lazy(), &variants, &zero_cost_config()).unwrap();

        assert_eq!(df.height(), 3);
        assert!(df.column("threshold").is_ok());
        assert!(df.column("num_trades").is_ok());
        assert!(df.column("final_equity").is_ok());
        assert!(df.column("total_return").is_ok());

        assert_f64_column(&df, "threshold", &[0.5, 1.0, 2.0]);
        assert_f64_column(&df, "num_trades", &[1.0, 1.0, 0.0]);
    }

    #[test]
    fn test_sweep_variants_produce_different_final_equity() {
        let variants = single_param_variants(
            "entry_bar",
            &[1.0, 2.0],
            &["signal_early", "signal_late"],
        )
        .unwrap();

        let df = run_param_sweep(sweep_base_df().lazy(), &variants, &zero_cost_config()).unwrap();
        assert_eq!(df.height(), 2);

        let equity = df.column("final_equity").unwrap().f64().unwrap();
        let e0 = equity.get(0).unwrap();
        let e1 = equity.get(1).unwrap();
        assert!(
            (e0 - e1).abs() > 1.0,
            "early vs late entry should differ: {e0} vs {e1}"
        );
    }

    #[test]
    fn test_sweep_multi_param_explicit_variants() {
        let variants = vec![
            variant(&[("stop_pct", 0.05), ("mode", 1.0)], "signal_early"),
            variant(&[("stop_pct", 0.10), ("mode", 1.0)], "signal_late"),
            variant(&[("stop_pct", 0.05), ("mode", 2.0)], "signal_flat"),
        ];

        let df = run_param_sweep(sweep_base_df().lazy(), &variants, &zero_cost_config()).unwrap();
        assert_eq!(df.height(), 3);
        assert!(df.column("mode").is_ok());
        assert!(df.column("stop_pct").is_ok());
        assert_eq!(df.column("mode").unwrap().f64().unwrap().len(), 3);

        // Parameter values must land in the row of the variant they belong to.
        assert_f64_column(&df, "mode", &[1.0, 1.0, 2.0]);
        assert_f64_column(&df, "stop_pct", &[0.05, 0.10, 0.05]);
    }

    #[test]
    fn test_single_param_variants_rejects_length_mismatch() {
        let result = single_param_variants("threshold", &[0.5, 1.0], &["signal_early"]);
        assert!(matches!(result, Err(BacktestError::InvalidInput(_))));
    }

    #[test]
    fn test_single_param_variants_rejects_empty_input() {
        let no_cols: [&str; 0] = [];
        let result = single_param_variants("threshold", &[], &no_cols);
        assert!(matches!(result, Err(BacktestError::InvalidInput(_))));
    }

    #[test]
    fn test_sweep_rejects_empty_variants() {
        let result = run_param_sweep(sweep_base_df().lazy(), &[], &zero_cost_config());
        assert!(matches!(result, Err(BacktestError::InvalidInput(_))));
    }

    #[test]
    fn test_sweep_rejects_variant_missing_param_key() {
        let variants = vec![
            variant(&[("stop_pct", 0.05), ("mode", 1.0)], "signal_early"),
            // Second variant lacks the "mode" key that the first one defines.
            variant(&[("stop_pct", 0.10)], "signal_late"),
        ];

        let result = run_param_sweep(sweep_base_df().lazy(), &variants, &zero_cost_config());
        match result {
            Err(BacktestError::InvalidInput(msg)) => {
                assert!(msg.contains("'mode'"), "unexpected message: {msg}")
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }
}