1use crate::{BacktestConfig, BacktestEngine, BacktestError, BacktestReport};
7use polars::prelude::{RankMethod, RankOptions, *};
8
/// Settings for a cross-sectional long/short portfolio built from a single factor column.
///
/// The fractions are stored as given; range checks happen where the configuration is
/// consumed, not at construction time.
#[derive(Clone, Debug, PartialEq)]
pub struct CrossSectionalConfig {
    /// Name of the column that holds the factor score.
    pub factor_col: String,
    /// Fraction of each cross-section taken from the top of the factor ranking.
    pub top_frac: f64,
    /// Fraction of each cross-section taken from the bottom of the factor ranking.
    pub bottom_frac: f64,
}

impl CrossSectionalConfig {
    /// Creates a long/short configuration from a factor column name and the two leg fractions.
    ///
    /// `factor_col` accepts anything convertible into a `String` (for example `&str` or
    /// `String`). No validation is performed here.
    pub fn long_short(factor_col: impl Into<String>, top_frac: f64, bottom_frac: f64) -> Self {
        let factor_col: String = factor_col.into();
        Self {
            factor_col,
            top_frac,
            bottom_frac,
        }
    }
}
28
29pub fn neutralize_factor(lf: LazyFrame, factor_col: &str, group_col: &str) -> LazyFrame {
31 lf.with_column(
32 (col(factor_col) - col(factor_col).mean().over([col(group_col)]))
33 .alias(factor_col)
34 )
35}
36
37pub fn zscore_factor(lf: LazyFrame, factor_col: &str, timestamp_col: &str) -> LazyFrame {
39 let mean = col(factor_col).mean().over([col(timestamp_col)]);
40 let std = col(factor_col).std(1).over([col(timestamp_col)]);
41 lf.with_column(
42 ((col(factor_col) - mean) / std).alias(factor_col)
43 )
44}
45
46pub fn winsorize_factor(
48 lf: LazyFrame,
49 factor_col: &str,
50 timestamp_col: &str,
51 lower_pct: f64,
52 upper_pct: f64,
53) -> LazyFrame {
54 let lower = col(factor_col)
55 .quantile(lit(lower_pct), QuantileMethod::Nearest)
56 .over([col(timestamp_col)]);
57 let upper = col(factor_col)
58 .quantile(lit(upper_pct), QuantileMethod::Nearest)
59 .over([col(timestamp_col)]);
60
61 lf.with_column(
62 when(col(factor_col).lt(lower.clone()))
63 .then(lower)
64 .when(col(factor_col).gt(upper.clone()))
65 .then(upper)
66 .otherwise(col(factor_col))
67 .alias(factor_col)
68 )
69}
70
71pub fn assign_long_short_exposure(
75 lf: LazyFrame,
76 timestamp_col: &str,
77 _symbol_col: &str,
78 cs: &CrossSectionalConfig,
79 exposure_col: &str,
80) -> Result<LazyFrame, BacktestError> {
81 if cs.top_frac <= 0.0 || cs.bottom_frac <= 0.0 {
82 return Err(BacktestError::InvalidInput(
83 "top_frac and bottom_frac must be > 0".into(),
84 ));
85 }
86 if cs.top_frac + cs.bottom_frac > 1.0 {
87 return Err(BacktestError::InvalidInput(
88 "top_frac + bottom_frac must be <= 1".into(),
89 ));
90 }
91
92 let n_per_ts = col(timestamp_col)
93 .count()
94 .over([col(timestamp_col)])
95 .cast(DataType::Float64);
96
97 let rank_best = col(&cs.factor_col)
99 .rank(
100 RankOptions {
101 method: RankMethod::Min,
102 descending: true,
103 },
104 None,
105 )
106 .over([col(timestamp_col)])
107 .cast(DataType::Float64);
108
109 let top_slots = when(
110 (n_per_ts.clone() * lit(cs.top_frac))
111 .cast(DataType::Int64)
112 .cast(DataType::Float64)
113 .lt(lit(1.0)),
114 )
115 .then(lit(1.0))
116 .otherwise(
117 (n_per_ts.clone() * lit(cs.top_frac))
118 .cast(DataType::Int64)
119 .cast(DataType::Float64),
120 );
121 let bottom_slots = when(
122 (n_per_ts.clone() * lit(cs.bottom_frac))
123 .cast(DataType::Int64)
124 .cast(DataType::Float64)
125 .lt(lit(1.0)),
126 )
127 .then(lit(1.0))
128 .otherwise(
129 (n_per_ts.clone() * lit(cs.bottom_frac))
130 .cast(DataType::Int64)
131 .cast(DataType::Float64),
132 );
133 let short_cut = n_per_ts - bottom_slots.clone() + lit(1.0);
134
135 let exposure = when(rank_best.clone().lt_eq(top_slots.clone()))
136 .then(lit(1.0) / top_slots.clone())
137 .when(rank_best.gt_eq(short_cut))
138 .then(lit(-1.0) / bottom_slots)
139 .otherwise(lit(0.0))
140 .alias(exposure_col);
141
142 Ok(lf.with_column(exposure))
143}
144
145pub fn run_cross_sectional_backtest(
147 lf: LazyFrame,
148 cs: &CrossSectionalConfig,
149 mut base_config: BacktestConfig,
150) -> Result<BacktestReport, BacktestError> {
151 let symbol_col = base_config
152 .symbol_col
153 .clone()
154 .ok_or_else(|| BacktestError::InvalidInput("symbol_col required for cross-sectional".into()))?;
155
156 const EXPOSURE: &str = "cs_exposure";
157 let with_exp = assign_long_short_exposure(
158 lf,
159 &base_config.timestamp_col,
160 &symbol_col,
161 cs,
162 EXPOSURE,
163 )?;
164 base_config.signal_col = EXPOSURE.to_string();
165 BacktestEngine::new(base_config).backtest_with_report(with_exp)
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Panel with two timestamps and four symbols; `score` descends from A (4.0) to D (1.0).
    fn panel_df() -> DataFrame {
        let columns = vec![
            Column::new("timestamp".into(), vec![1i64, 1, 1, 1, 2, 2, 2, 2]),
            Column::new(
                "symbol".into(),
                vec!["A", "B", "C", "D", "A", "B", "C", "D"],
            ),
            Column::new(
                "close".into(),
                vec![10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0],
            ),
            Column::new(
                "score".into(),
                vec![4.0, 3.0, 2.0, 1.0, 4.0, 3.0, 2.0, 1.0],
            ),
        ];
        DataFrame::new(columns).unwrap()
    }

    /// Reads the `score` column as plain `f64` values, panicking on any null.
    fn score_values(frame: &DataFrame) -> Vec<f64> {
        frame
            .column("score")
            .unwrap()
            .f64()
            .unwrap()
            .into_iter()
            .map(|value| value.unwrap())
            .collect()
    }

    #[test]
    fn test_factor_neutralize_demean_within_sector() {
        let input = DataFrame::new(vec![
            Column::new("sector".into(), vec!["Tech", "Tech", "Fin", "Fin"]),
            Column::new("score".into(), vec![10.0, 20.0, 100.0, 200.0]),
        ])
        .unwrap();

        let neutral = neutralize_factor(input.lazy(), "score", "sector")
            .collect()
            .unwrap();
        let demeaned = score_values(&neutral);

        let expected = [-5.0, 5.0, -50.0, 50.0];
        for (row, want) in expected.iter().enumerate() {
            assert_relative_eq!(demeaned[row], *want, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_factor_zscore_zero_mean_smoke() {
        let standardized = zscore_factor(panel_df().lazy(), "score", "timestamp")
            .collect()
            .unwrap();
        let z = score_values(&standardized);

        // The first four rows form the first timestamp's cross-section.
        let first_ts_mean = z[..4].iter().sum::<f64>() / 4.0;
        assert_relative_eq!(first_ts_mean, 0.0, epsilon = 1e-9);
    }

    #[test]
    fn test_factor_winsorize_clips_extremes() {
        let input = DataFrame::new(vec![
            Column::new("timestamp".into(), vec![1i64, 1, 1, 1, 1]),
            Column::new("score".into(), vec![0.0, 10.0, 20.0, 30.0, 100.0]),
        ])
        .unwrap();

        let clipped_frame = winsorize_factor(input.lazy(), "score", "timestamp", 0.2, 0.8)
            .collect()
            .unwrap();
        let clipped = score_values(&clipped_frame);

        // (row index, expected value after clipping)
        for (row, want) in [(0usize, 10.0), (1, 10.0), (4, 30.0)] {
            assert_relative_eq!(clipped[row], want, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_assign_long_short_exposure_top_bottom() {
        let cs = CrossSectionalConfig::long_short("score", 0.25, 0.25);
        let weighted = assign_long_short_exposure(
            panel_df().lazy(),
            "timestamp",
            "symbol",
            &cs,
            "exposure",
        )
        .unwrap()
        .collect()
        .unwrap();

        let first_ts: Vec<f64> = weighted
            .column("exposure")
            .unwrap()
            .f64()
            .unwrap()
            .into_iter()
            .take(4)
            .map(|weight| weight.unwrap())
            .collect();

        let longs = first_ts.iter().filter(|weight| **weight > 0.0).count();
        let shorts = first_ts.iter().filter(|weight| **weight < 0.0).count();
        let gross: f64 = first_ts.iter().map(|weight| weight.abs()).sum();

        assert_eq!(longs, 1);
        assert_eq!(shorts, 1);
        assert_relative_eq!(gross, 2.0, epsilon = 1e-9);
    }

    #[test]
    fn test_cross_sectional_backtest_smoke() {
        let cs = CrossSectionalConfig::long_short("score", 0.25, 0.25);
        let cfg = BacktestConfig {
            cost_model: crate::CostModel {
                commission_bps: 0.0,
                slippage_bps: 0.0,
                initial_cash: 100_000.0,
            },
            symbol_col: Some("symbol".into()),
            ..Default::default()
        };

        // Replace the factor with a constant so every symbol has the same score.
        let flat_scores = panel_df()
            .lazy()
            .with_column(lit(1.0).alias("score"))
            .collect()
            .unwrap();

        let report = run_cross_sectional_backtest(flat_scores.lazy(), &cs, cfg).unwrap();
        assert!(report.metrics.final_equity.is_finite());
    }

    #[test]
    fn test_cross_sectional_invalid_fracs_error() {
        let cs = CrossSectionalConfig::long_short("score", 0.6, 0.6);
        let outcome = assign_long_short_exposure(
            panel_df().lazy(),
            "timestamp",
            "symbol",
            &cs,
            "exposure",
        );

        let Err(error) = outcome else {
            panic!("expected invalid frac error");
        };
        assert!(error.to_string().contains("top_frac"));
    }
}