use crate::core::native_engine::{DailysSoA, PairsSoA};
use anyhow::Context;
use errors::WbtError;
use polars::prelude::*;
use report::Report;
use std::path::Path;
mod backtest;
pub mod daily_performance;
pub mod errors;
mod evaluate_pairs;
pub mod native_engine;
pub mod period_win_rates;
mod report;
pub mod top_drawdowns;
pub mod trade_dir;
pub mod utils;
pub mod yearly_return;
pub use utils::WeightType;
/// Weight-based backtest engine over a Polars `DataFrame` of
/// `(dt, symbol, weight, price)` rows.
///
/// Construct with [`WeightBacktest::new`] or [`WeightBacktest::from_file`],
/// then call `backtest` to populate the result fields/caches.
pub struct WeightBacktest {
    /// Input frame, normalized by `new`: rows regrouped by symbol with a
    /// leading integer `sym_id` column (columns: sym_id, dt, weight, price, symbol).
    pub dfw: DataFrame,
    /// Decimal digits used for rounding downstream — NOTE(review): exact
    /// usage lives in the engine modules, confirm semantics there.
    pub digits: i64,
    /// Proportional fee rate; defaults to 0.0002 when not supplied to `new`.
    pub fee_rate: f64,
    /// Unique symbols in lexicographic order; a symbol's index is its `sym_id`.
    pub symbols: Vec<Arc<str>>,
    // Populated by a backtest run; `None` until `backtest` has been called.
    dailys_soa: Option<DailysSoA>,
    pairs_soa: Option<PairsSoA>,
    // Lazily materialized DataFrame views over the SoA results (see the
    // corresponding `*_df` accessors).
    daily_return_cache: Option<DataFrame>,
    dailys_cache: Option<DataFrame>,
    pairs_cache: Option<DataFrame>,
    // Weight interpretation chosen at backtest time.
    weight_type: Option<WeightType>,
    /// Aggregate report produced by the backtest run; `None` before `backtest`.
    pub report: Option<Report>,
    /// Trading days per year used for annualization (default 252).
    pub yearly_days: usize,
}
impl WeightBacktest {
/// Build a `WeightBacktest` from a raw `(dt, symbol, weight, price)` frame.
///
/// Pipeline: normalize and sort the `dt` column, round weights to 4
/// decimals, collect the sorted unique symbol list, then physically
/// regroup all rows by symbol via a stable counting sort and prepend a
/// dense integer `sym_id` column so each symbol's rows are contiguous
/// (and, within a symbol, still in ascending `dt` order).
///
/// # Errors
/// Propagates failures from datetime conversion, weight rounding, symbol
/// extraction, and DataFrame reassembly.
pub fn new(dfw: DataFrame, digits: i64, fee_rate: Option<f64>) -> Result<Self, WbtError> {
    let mut dfw = Self::convert_datetime(dfw).context("Failed to convert datetime")?;
    Self::round_weight(&mut dfw).context("Failed to round weight")?;
    let symbols = Self::unique_symbols(&dfw).context("Failed to unique_symbols")?;
    let dfw = {
        let n_rows = dfw.height();
        let n_syms = symbols.len();
        // symbol -> dense id, following the sorted order of `symbols`.
        let mut order_map: hashbrown::HashMap<&str, u32> =
            hashbrown::HashMap::with_capacity(n_syms);
        for (idx, sym) in symbols.iter().enumerate() {
            order_map.insert(sym.as_ref(), idx as u32);
        }
        let sym_ca = dfw.column("symbol")?.as_materialized_series().str()?;
        // Per-row symbol id. A null symbol falls back to id 0 and is lumped
        // into the first symbol's bucket — NOTE(review): confirm null
        // symbols are impossible or filtered upstream.
        let sym_ids: Vec<u32> = sym_ca
            .into_iter()
            .map(|opt_s| opt_s.and_then(|s| order_map.get(s).copied()).unwrap_or(0))
            .collect();
        drop(order_map);
        // Counting sort, pass 1: bucket sizes per symbol.
        let mut bucket_counts = vec![0u32; n_syms];
        for &sid in &sym_ids {
            bucket_counts[sid as usize] += 1;
        }
        // Pass 2: exclusive prefix sum -> each bucket's starting write offset.
        let mut write_pos = vec![0u32; n_syms];
        let mut acc = 0u32;
        for i in 0..n_syms {
            write_pos[i] = acc;
            acc += bucket_counts[i];
        }
        // Pass 3: scatter original row indices into their buckets. Iterating
        // rows in order keeps the sort stable, preserving dt order per symbol.
        let mut perm = vec![0u32; n_rows];
        for (i, &sid_val) in sym_ids.iter().enumerate().take(n_rows) {
            let sid = sid_val as usize;
            perm[write_pos[sid] as usize] = i as u32;
            write_pos[sid] += 1;
        }
        // Apply the permutation to every column via a single take index.
        let perm_idx = IdxCa::new(PlSmallStr::from("idx"), &perm);
        let sym_id_vals: Vec<u32> = perm.iter().map(|&i| sym_ids[i as usize]).collect();
        DataFrame::new(vec![
            Column::new("sym_id".into(), sym_id_vals),
            dfw.column("dt")?
                .as_materialized_series()
                .take(&perm_idx)?
                .into_column(),
            dfw.column("weight")?
                .as_materialized_series()
                .take(&perm_idx)?
                .into_column(),
            dfw.column("price")?
                .as_materialized_series()
                .take(&perm_idx)?
                .into_column(),
            dfw.column("symbol")?
                .as_materialized_series()
                .take(&perm_idx)?
                .into_column(),
        ])?
    };
    let wb = Self {
        dfw,
        digits,
        symbols,
        fee_rate: fee_rate.unwrap_or(0.0002),
        dailys_soa: None,
        pairs_soa: None,
        daily_return_cache: None,
        dailys_cache: None,
        pairs_cache: None,
        weight_type: None,
        report: None,
        yearly_days: 252,
    };
    Ok(wb)
}
/// Load a weight frame from `path` (csv, parquet, feather/arrow chosen by
/// extension), validate the required columns, and build a `WeightBacktest`.
///
/// # Errors
/// - `WbtError::Io` for an unsupported extension, unreadable file, or a
///   missing required column (`dt`, `symbol`, `weight`, `price`).
/// - `WbtError::Polars` when the reader fails to parse the file.
pub fn from_file(path: &str, digits: i64, fee_rate: Option<f64>) -> Result<Self, WbtError> {
    let p = Path::new(path);
    let ext = p
        .extension()
        .and_then(|e| e.to_str())
        .unwrap_or("")
        .to_lowercase();
    // Shared open helper so every format arm maps I/O errors identically.
    let open = || std::fs::File::open(p).map_err(|e| WbtError::Io(e.to_string()));
    let df = match ext.as_str() {
        "csv" => CsvReader::new(open()?).finish().map_err(WbtError::Polars)?,
        "parquet" => ParquetReader::new(open()?)
            .finish()
            .map_err(WbtError::Polars)?,
        "feather" | "arrow" => IpcReader::new(open()?).finish().map_err(WbtError::Polars)?,
        other => {
            return Err(WbtError::Io(format!(
                "Unsupported file format: '{}'. Supported: csv, parquet, feather, arrow",
                other
            )));
        }
    };
    // Fail fast with a descriptive error before handing off to `new`.
    for col in ["dt", "symbol", "weight", "price"] {
        if df.column(col).is_err() {
            return Err(WbtError::Io(format!(
                "Missing required column '{}' in file '{}'",
                col, path
            )));
        }
    }
    Self::new(df, digits, fee_rate)
}
/// Run the backtest on a dedicated rayon pool.
///
/// `n_jobs` defaults to 4 worker threads when `None`; each worker gets a
/// 64 MiB stack. Results are stored on `self` by `do_backtest`.
pub fn backtest(
    &mut self,
    n_jobs: Option<usize>,
    weight_type: WeightType,
    yearly_days: usize,
) -> Result<(), WbtError> {
    let threads = n_jobs.unwrap_or(4);
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(threads)
        .stack_size(64 * 1024 * 1024)
        .build()
        .context("Failed to create thread pool")?;
    pool.install(|| self.do_backtest(weight_type, yearly_days))
}
/// Daily-return DataFrame, built lazily on first access and cached.
///
/// # Errors
/// `WbtError::NoneValue` if `backtest` has not populated `dailys_soa`,
/// `report`, or `weight_type` yet.
pub fn daily_return_df(&mut self) -> Result<&mut DataFrame, WbtError> {
    // Cache hit: hand out the existing frame.
    if self.daily_return_cache.is_some() {
        return Ok(self.daily_return_cache.as_mut().unwrap());
    }
    let dailys_soa = self
        .dailys_soa
        .as_ref()
        .ok_or_else(|| WbtError::NoneValue("dailys_soa not computed yet".into()))?;
    let report = self
        .report
        .as_ref()
        .ok_or_else(|| WbtError::NoneValue("report not computed yet".into()))?;
    let weight_type = self
        .weight_type
        .ok_or_else(|| WbtError::NoneValue("weight_type not computed yet".into()))?;
    let df = Self::build_daily_return_df(dailys_soa, &report.daily_totals, weight_type)?;
    Ok(self.daily_return_cache.get_or_insert(df))
}
/// Per-day results DataFrame, materialized from `dailys_soa` on first
/// access and cached.
///
/// # Errors
/// `WbtError::NoneValue` if `backtest` has not populated `dailys_soa` yet.
pub fn dailys_df(&mut self) -> Result<&mut DataFrame, WbtError> {
    if self.dailys_cache.is_some() {
        return Ok(self.dailys_cache.as_mut().unwrap());
    }
    let soa = self
        .dailys_soa
        .as_ref()
        .ok_or_else(|| WbtError::NoneValue("dailys_soa not computed yet".into()))?;
    let df = soa.to_dataframe()?;
    Ok(self.dailys_cache.get_or_insert(df))
}
/// Trade-pairs DataFrame, or `Ok(None)` when no pairs were produced.
/// Materialized from `pairs_soa` on first access and cached.
pub fn pairs_df(&mut self) -> Result<Option<&mut DataFrame>, WbtError> {
    if self.pairs_soa.is_none() {
        return Ok(None);
    }
    if self.pairs_cache.is_none() {
        let soa = self.pairs_soa.as_ref().expect("checked is_none above");
        self.pairs_cache = Some(soa.to_dataframe()?);
    }
    Ok(self.pairs_cache.as_mut())
}
/// Yearly-return summary derived from the (cached) daily-return frame.
/// Years with fewer than `min_days` observations are filtered out.
pub fn yearly_return_df(&mut self, min_days: usize) -> Result<DataFrame, WbtError> {
    let daily = self.daily_return_df()?;
    yearly_return::compute_yearly_returns(daily, min_days)
}
/// Daily alpha frame: date, excess return ("超额"), strategy mean ("策略"),
/// and benchmark mean ("基准"), one row per date key in the report.
///
/// # Errors
/// `WbtError::NoneValue` if `backtest` has not produced a report yet.
pub fn alpha_df(&self) -> Result<DataFrame, WbtError> {
    let report = self
        .report
        .as_ref()
        .ok_or_else(|| WbtError::NoneValue("report not computed yet".into()))?;
    let totals = &report.daily_totals;
    // Date keys -> days since the Unix epoch, which is what Polars' Date
    // dtype stores internally.
    let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
    let dates: Vec<i32> = totals
        .date_keys
        .iter()
        .map(|dk| (utils::date_key_to_naive_date(*dk) - epoch).num_days() as i32)
        .collect();
    // Excess return per day: strategy minus benchmark.
    let n = totals.strategy_means.len();
    let excess: Vec<f64> = (0..n)
        .map(|i| totals.strategy_means[i] - totals.benchmark_means[i])
        .collect();
    DataFrame::new(vec![
        Series::new("date".into(), dates)
            .cast(&DataType::Date)
            .map_err(WbtError::Polars)?
            .into_column(),
        Series::new("超额".into(), excess).into_column(),
        Series::new("策略".into(), &totals.strategy_means).into_column(),
        Series::new("基准".into(), &totals.benchmark_means).into_column(),
    ])
    .map_err(WbtError::Polars)
}
}
impl WeightBacktest {
pub(crate) fn unique_symbols(df: &DataFrame) -> Result<Vec<Arc<str>>, WbtError> {
let symbols_series = df.column("symbol")?.as_materialized_series().str()?;
let mut unique_symbols_set = hashbrown::HashSet::new();
for symbol in symbols_series.into_iter().flatten() {
unique_symbols_set.insert(symbol);
}
let mut unique_symbols: Vec<Arc<str>> =
unique_symbols_set.into_iter().map(Arc::from).collect();
unique_symbols.sort_unstable();
Ok(unique_symbols)
}
/// Sort the frame ascending by its `dt` column.
fn sort_by_dt(df: DataFrame) -> Result<DataFrame, WbtError> {
    let ascending = SortMultipleOptions::default().with_order_descending(false);
    df.lazy()
        .sort(["dt"], ascending)
        .collect()
        .map_err(|e| anyhow::anyhow!("Failed to sort by dt: {e}").into())
}
/// Normalize the `dt` column to a Polars `Datetime` and sort the frame
/// ascending by it.
///
/// Accepted input dtypes for `dt`:
/// - `Datetime(ns, _)`: kept as-is, rows just sorted.
/// - `Datetime(ms, _)` / `Datetime(us, _)`: cast to timezone-naive
///   millisecond precision, then sorted.
/// - `Int64`: interpreted as epoch *seconds*, scaled to milliseconds.
/// - `String`: parsed with the fixed format `%Y-%m-%d %H:%M:%S`.
///
/// # Errors
/// Any other dtype, a failed cast/parse, or a failed sort.
pub(crate) fn convert_datetime(mut df: DataFrame) -> Result<DataFrame, WbtError> {
    let dt_col = df.column("dt")?.as_materialized_series().clone();
    let dt_type = dt_col.dtype().clone();
    match &dt_type {
        DataType::Datetime(TimeUnit::Nanoseconds, _) => Ok(Self::sort_by_dt(df)?),
        // Fix: microsecond datetimes (the default unit pandas writes to
        // parquet) previously fell through to the "Unsupported" arm; they
        // are now normalized exactly like millisecond datetimes.
        DataType::Datetime(TimeUnit::Milliseconds, _)
        | DataType::Datetime(TimeUnit::Microseconds, _) => {
            let dt_cast = dt_col.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?;
            let _ = df.replace("dt", dt_cast)?;
            Ok(Self::sort_by_dt(df)?)
        }
        DataType::Int64 => {
            // Epoch seconds -> epoch milliseconds before casting to Datetime(ms).
            let parsed_col = dt_col
                .i64()?
                .into_iter()
                .map(|opt_ts| opt_ts.map(|ts| ts * 1000));
            let dt_s = Series::from_iter(parsed_col)
                .cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?;
            let _ = df.replace("dt", dt_s)?;
            Ok(Self::sort_by_dt(df)?)
        }
        DataType::String => {
            // `strict: true` raises on unparseable rows; `exact: false`
            // tolerates extra trailing characters after the matched format.
            let df = df
                .lazy()
                .with_column(col("dt").str().to_datetime(
                    Some(TimeUnit::Milliseconds),
                    None,
                    StrptimeOptions {
                        format: Some("%Y-%m-%d %H:%M:%S".into()),
                        strict: true,
                        exact: false,
                        cache: true,
                    },
                    lit("raise"),
                ))
                .sort(
                    ["dt"],
                    SortMultipleOptions::default().with_order_descending(false),
                )
                .collect()
                .context("Failed to convert datetime")?;
            Ok(df)
        }
        _ => Err(anyhow::anyhow!("Unsupported datetime type: {:?}", dt_type).into()),
    }
}
pub(crate) fn round_weight(df: &mut DataFrame) -> Result<(), WbtError> {
let weight_s = df.column("weight")?.as_materialized_series().clone();
let rounded = weight_s
.f64()
.unwrap()
.into_iter()
.map(|opt| opt.map(|val| (val * 10000.0).round() / 10000.0))
.collect::<Float64Chunked>();
let _ = df.replace("weight", rounded.into_series())?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: five rows of a single symbol with string datetimes,
    /// 4-decimal weights (long, flat, short), and prices.
    fn raw_example_data() -> DataFrame {
        df! {
            "dt" => &[
                "2019-01-02 09:01:00",
                "2019-01-03 09:02:00",
                "2019-01-04 09:03:00",
                "2019-01-05 09:04:00",
                "2019-01-06 09:05:00"
            ],
            "symbol" => &["DLi9001"; 5],
            "weight" => &[
                0.511,
                0.000,
                -0.250,
                0.000,
                0.000
            ],
            "price" => &[
                961.695,
                960.720,
                962.669,
                960.720,
                961.695
            ]
        }
        .unwrap()
    }

    // Weights already at 4 decimals must pass through round_weight unchanged.
    #[test]
    fn test_round_weight() {
        let mut df = raw_example_data();
        WeightBacktest::round_weight(&mut df).unwrap();
        let weights: Vec<f64> = df
            .column("weight")
            .unwrap()
            .as_materialized_series()
            .f64()
            .unwrap()
            .into_no_null_iter()
            .collect();
        assert_eq!(weights, vec![0.511, 0.0, -0.25, 0.0, 0.0]);
    }

    // String datetimes parse into a Datetime column without dropping rows.
    #[test]
    fn test_convert_datetime() {
        let df = raw_example_data();
        let df = WeightBacktest::convert_datetime(df).unwrap();
        assert!(matches!(
            df.column("dt").unwrap().dtype(),
            DataType::Datetime(_, _)
        ));
        assert_eq!(df.height(), 5);
    }

    // Single-symbol fixture yields exactly one unique symbol.
    #[test]
    fn test_unique_symbols() {
        let df = raw_example_data();
        let symbols = WeightBacktest::unique_symbols(&df).unwrap();
        assert_eq!(symbols, vec![Arc::from("DLi9001")]);
    }

    // Default construction: fee_rate falls back to 0.0002, fields populated.
    #[test]
    fn new_valid_dataframe() {
        let df = raw_example_data();
        let wb = WeightBacktest::new(df, 2, None).unwrap();
        assert_eq!(wb.fee_rate, 0.0002);
        assert_eq!(wb.digits, 2);
        assert!(!wb.symbols.is_empty());
    }

    // An explicit fee rate overrides the default.
    #[test]
    fn new_custom_fee_rate() {
        let df = raw_example_data();
        let wb = WeightBacktest::new(df, 2, Some(0.001)).unwrap();
        assert_eq!(wb.fee_rate, 0.001);
    }

    // End-to-end: backtest then yearly returns; expects one symbol row plus
    // a "total" row, both for 2019.
    #[test]
    fn yearly_return_df_end_to_end_minimal() {
        let df = raw_example_data();
        let mut wb = WeightBacktest::new(df, 2, None).unwrap();
        wb.backtest(Some(1), WeightType::TS, 252).unwrap();
        let y = wb.yearly_return_df(1).unwrap();
        assert_eq!(y.height(), 2);
        let years: Vec<i32> = y
            .column("year")
            .unwrap()
            .as_materialized_series()
            .i32()
            .unwrap()
            .into_no_null_iter()
            .collect();
        assert_eq!(years, vec![2019, 2019]);
        let syms: Vec<String> = y
            .column("symbol")
            .unwrap()
            .as_materialized_series()
            .str()
            .unwrap()
            .into_no_null_iter()
            .map(|s: &str| s.to_string())
            .collect();
        assert_eq!(syms, vec!["DLi9001".to_string(), "total".to_string()]);
    }

    // min_days larger than the observation count filters every year out.
    #[test]
    fn yearly_return_df_filters_when_below_min_days() {
        let df = raw_example_data();
        let mut wb = WeightBacktest::new(df, 2, None).unwrap();
        wb.backtest(Some(1), WeightType::TS, 252).unwrap();
        let y = wb.yearly_return_df(120).unwrap();
        assert_eq!(y.height(), 0);
    }

    // The cache is only built on first access and the same allocation is
    // handed back afterwards (pointer equality across two calls).
    #[test]
    fn daily_return_cache_is_lazy_and_reused() {
        let df = raw_example_data();
        let mut wb = WeightBacktest::new(df, 2, None).unwrap();
        wb.backtest(Some(1), WeightType::TS, 252).unwrap();
        assert!(wb.daily_return_cache.is_none());
        let first_ptr = {
            let df = wb.daily_return_df().unwrap();
            df as *mut DataFrame
        };
        assert!(wb.daily_return_cache.is_some());
        let second_ptr = {
            let df = wb.daily_return_df().unwrap();
            df as *mut DataFrame
        };
        assert_eq!(first_ptr, second_ptr);
    }

    // Missing `price` column must make construction fail.
    #[test]
    fn new_missing_column() {
        let df = df! {
            "dt" => &["2019-01-02 09:01:00"],
            "symbol" => &["A"],
            "weight" => &[0.5_f64]
        }
        .unwrap();
        assert!(WeightBacktest::new(df, 2, None).is_err());
    }

    // Int64 dt values (epoch seconds) are converted to a Datetime column.
    #[test]
    fn convert_datetime_int64() {
        let df = df! {
            "dt" => &[1546398060_i64, 1546484520_i64],
            "symbol" => &["A", "A"],
            "weight" => &[0.5_f64, -0.5],
            "price" => &[100.0, 101.0]
        }
        .unwrap();
        let result = WeightBacktest::convert_datetime(df);
        assert!(result.is_ok());
        let df = result.unwrap();
        assert!(matches!(
            df.column("dt").unwrap().dtype(),
            DataType::Datetime(_, _)
        ));
    }

    // Values with more than 4 decimals round half-away to 4 decimals.
    #[test]
    fn round_weight_precision() {
        let mut df = df! {
            "dt" => &["2019-01-02 09:01:00"],
            "symbol" => &["A"],
            "weight" => &[0.12345678_f64],
            "price" => &[100.0]
        }
        .unwrap();
        WeightBacktest::round_weight(&mut df).unwrap();
        let w = df
            .column("weight")
            .unwrap()
            .as_materialized_series()
            .f64()
            .unwrap()
            .get(0)
            .unwrap();
        assert_eq!(w, 0.1235);
    }

    // Zero weight stays exactly zero after rounding.
    #[test]
    fn round_weight_zero() {
        let mut df = df! {
            "dt" => &["2019-01-02 09:01:00"],
            "symbol" => &["A"],
            "weight" => &[0.0_f64],
            "price" => &[100.0]
        }
        .unwrap();
        WeightBacktest::round_weight(&mut df).unwrap();
        let w = df
            .column("weight")
            .unwrap()
            .as_materialized_series()
            .f64()
            .unwrap()
            .get(0)
            .unwrap();
        assert_eq!(w, 0.0);
    }

    // Round-trip through a temp CSV file: both symbols load.
    #[test]
    fn from_file_csv() {
        let dir = std::env::temp_dir().join("wbt_test_from_file");
        std::fs::create_dir_all(&dir).unwrap();
        let csv_path = dir.join("test.csv");
        let csv_content = "dt,symbol,weight,price\n\
            2024-01-01 09:30:00,SYM_A,0.5,100.0\n\
            2024-01-02 09:30:00,SYM_A,-0.3,101.0\n\
            2024-01-01 09:30:00,SYM_B,0.2,50.0\n\
            2024-01-02 09:30:00,SYM_B,0.0,51.0\n";
        std::fs::write(&csv_path, csv_content).unwrap();
        let wb = WeightBacktest::from_file(csv_path.to_str().unwrap(), 2, None).unwrap();
        assert_eq!(wb.symbols.len(), 2);
        assert!(wb.symbols.contains(&std::sync::Arc::from("SYM_A")));
        assert!(wb.symbols.contains(&std::sync::Arc::from("SYM_B")));
        std::fs::remove_dir_all(&dir).ok();
    }

    // A CSV lacking the `price` column is rejected by from_file.
    #[test]
    fn from_file_missing_column() {
        let dir = std::env::temp_dir().join("wbt_test_missing_col");
        std::fs::create_dir_all(&dir).unwrap();
        let csv_path = dir.join("bad.csv");
        std::fs::write(&csv_path, "dt,symbol,weight\n2024-01-01,A,0.5\n").unwrap();
        let result = WeightBacktest::from_file(csv_path.to_str().unwrap(), 2, None);
        assert!(result.is_err());
        std::fs::remove_dir_all(&dir).ok();
    }

    // Unknown extensions produce an "Unsupported" error without touching disk.
    #[test]
    fn from_file_unsupported_ext() {
        let result = WeightBacktest::from_file("/tmp/test.xlsx", 2, None);
        assert!(result.is_err());
        let err_msg = match result {
            Err(e) => e.to_string(),
            Ok(_) => unreachable!(),
        };
        assert!(err_msg.contains("Unsupported"));
    }

    // Round-trip through a temp parquet file: both symbols load.
    #[test]
    fn from_file_parquet() {
        let dir = std::env::temp_dir().join("wbt_test_from_file_parquet");
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("test.parquet");
        let df = df! {
            "dt" => &[
                "2024-01-01 09:30:00",
                "2024-01-02 09:30:00",
                "2024-01-01 09:30:00",
                "2024-01-02 09:30:00",
            ],
            "symbol" => &["SYM_A", "SYM_A", "SYM_B", "SYM_B"],
            "weight" => &[0.5_f64, -0.3, 0.2, 0.0],
            "price" => &[100.0_f64, 101.0, 50.0, 51.0]
        }
        .unwrap();
        let file = std::fs::File::create(&path).unwrap();
        ParquetWriter::new(file).finish(&mut df.clone()).unwrap();
        let wb = WeightBacktest::from_file(path.to_str().unwrap(), 2, None).unwrap();
        assert_eq!(wb.symbols.len(), 2);
        assert!(wb.symbols.contains(&std::sync::Arc::from("SYM_A")));
        assert!(wb.symbols.contains(&std::sync::Arc::from("SYM_B")));
        std::fs::remove_dir_all(&dir).ok();
    }

    // Round-trip through a temp feather (Arrow IPC) file: both symbols load.
    #[test]
    fn from_file_feather() {
        let dir = std::env::temp_dir().join("wbt_test_from_file_feather");
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("test.feather");
        let df = df! {
            "dt" => &[
                "2024-01-01 09:30:00",
                "2024-01-02 09:30:00",
                "2024-01-01 09:30:00",
                "2024-01-02 09:30:00",
            ],
            "symbol" => &["SYM_A", "SYM_A", "SYM_B", "SYM_B"],
            "weight" => &[0.5_f64, -0.3, 0.2, 0.0],
            "price" => &[100.0_f64, 101.0, 50.0, 51.0]
        }
        .unwrap();
        let file = std::fs::File::create(&path).unwrap();
        IpcWriter::new(file).finish(&mut df.clone()).unwrap();
        let wb = WeightBacktest::from_file(path.to_str().unwrap(), 2, None).unwrap();
        assert_eq!(wb.symbols.len(), 2);
        assert!(wb.symbols.contains(&std::sync::Arc::from("SYM_A")));
        assert!(wb.symbols.contains(&std::sync::Arc::from("SYM_B")));
        std::fs::remove_dir_all(&dir).ok();
    }

    // unique_symbols returns lexicographically sorted output regardless of
    // input order.
    #[test]
    fn unique_symbols_sorted_order() {
        let df = df! {
            "dt" => &["2019-01-02", "2019-01-02", "2019-01-02"],
            "symbol" => &["C", "A", "B"],
            "weight" => &[0.1, 0.2, 0.3],
            "price" => &[1.0, 2.0, 3.0]
        }
        .unwrap();
        let syms = WeightBacktest::unique_symbols(&df).unwrap();
        assert_eq!(syms, vec![Arc::from("A"), Arc::from("B"), Arc::from("C")]);
    }

    // Nanosecond datetimes are passed through without changing the time unit.
    #[test]
    fn convert_datetime_nanoseconds_passthrough() {
        let dates: Vec<i64> = vec![
            1_704_067_200_000_000_000,
            1_704_153_600_000_000_000,
        ];
        let dt_series = Series::new("dt".into(), dates)
            .cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))
            .unwrap();
        let df = DataFrame::new(vec![
            dt_series.into_column(),
            Series::new("symbol".into(), &["A", "A"]).into_column(),
            Series::new("weight".into(), &[0.5_f64, 0.0]).into_column(),
            Series::new("price".into(), &[100.0_f64, 101.0]).into_column(),
        ])
        .unwrap();
        let result = WeightBacktest::convert_datetime(df);
        assert!(result.is_ok());
        let df = result.unwrap();
        assert!(matches!(
            df.column("dt").unwrap().dtype(),
            DataType::Datetime(TimeUnit::Nanoseconds, _)
        ));
        assert_eq!(df.height(), 2);
    }
}