from __future__ import annotations
from pathlib import Path
import pandas as pd
import polars as pl
from wbt._df_convert import arrow_bytes_to_pd_df, pandas_to_arrow_bytes, polars_to_arrow_bytes
from wbt._wbt import PyWeightBacktest, daily_performance
# Preferred key order for the stats dictionaries returned by this module.
# `_reorder_stats` emits the keys listed here first (in this order, skipping
# any that are absent) and then appends every remaining key unchanged.
# The entries are runtime dictionary keys and must match the engine's output
# exactly, so they are intentionally left untranslated.
STATS_FIELD_ORDER = [
"绝对收益",
"年化收益",
"夏普比率",
"卡玛比率",
"新高占比",
"单笔盈亏比",
"单笔收益",
"日胜率",
"周胜率",
"月胜率",
"季胜率",
"年胜率",
"最大回撤",
"年化波动率",
"下行波动率",
"新高间隔",
"交易次数",
"年化交易次数",
"持仓K线数",
"交易胜率",
"多头占比",
"空头占比",
"品种数量",
"开始日期",
"结束日期",
]
def _reorder_stats(d: dict[str, object]) -> dict[str, object]:
    """Return a copy of ``d`` with its keys arranged in the canonical stats order.

    Keys listed in ``STATS_FIELD_ORDER`` come first, in that order (missing
    ones are skipped); every other key follows in the order it appears in
    ``d``. The input mapping is not modified.
    """
    prioritized = {name: d[name] for name in STATS_FIELD_ORDER if name in d}
    leftovers = {name: d[name] for name in d if name not in prioritized}
    return {**prioritized, **leftovers}
def _to_date_key(value: object) -> int | None:
if value is None:
return None
if isinstance(value, int):
return value
if isinstance(value, str):
if len(value) == 8 and value.isdigit():
return int(value)
return int(pd.Timestamp(value).strftime("%Y%m%d"))
return int(pd.Timestamp(str(value)).strftime("%Y%m%d"))
# Input kinds accepted by `WeightBacktest` / `backtest`: a pandas or polars
# frame, or a `str` / `Path` that is handed to `PyWeightBacktest.from_file`.
# NOTE(review): the name looks like a typo for "WEIGHT"; it is left as-is
# because other modules may import it under this spelling.
WEIGH_DATA_TYPE = pd.DataFrame | pl.DataFrame | pl.LazyFrame | str | Path
class WeightBacktest:
    """pandas-friendly wrapper around the native ``PyWeightBacktest`` engine.

    The constructor hands the weight data to the engine (by file path, or as
    Arrow bytes for polars / pandas frames). The properties convert the
    engine's Arrow payloads into pandas DataFrames and its stats dictionaries
    into dictionaries ordered by ``STATS_FIELD_ORDER``.

    Attributes set by ``__init__``:
        digits, fee_rate, weight_type, yearly_days: the constructor arguments.
        dfw: the normalized pandas frame (``dt``, ``symbol``, ``weight``,
            ``price``) for pandas input; ``None`` for file and polars input.
        symbols: symbol names (from the engine for file / polars input, from
            the ``symbol`` column for pandas input).
    """

    def __init__(
        self,
        data: WEIGH_DATA_TYPE,
        digits: int = 2,
        fee_rate: float = 0.0002,
        n_jobs: int = 1,
        weight_type: str = "ts",
        yearly_days: int = 252,
    ) -> None:
        """Create the backtest from ``data``.

        Args:
            data: A ``str`` / ``Path`` (loaded by the engine itself), a polars
                ``DataFrame`` / ``LazyFrame``, or a pandas ``DataFrame`` with
                at least the columns ``dt``, ``symbol``, ``weight``, ``price``.
            digits: Decimal places the pandas path rounds ``weight`` to; also
                forwarded to the engine.
            fee_rate: Forwarded to the engine.
            n_jobs: Forwarded to the engine.
            weight_type: Forwarded to the engine; also selects how
                ``long_daily_return`` / ``short_daily_return`` build their
                ``total`` column (``"ts"`` mean, ``"cs"`` sum).
            yearly_days: Forwarded to the engine and to ``daily_performance``.

        Raises:
            ValueError: If a pandas ``data`` frame contains any null value.
        """
        self.digits = digits
        self.fee_rate = fee_rate
        self.weight_type = weight_type
        self.yearly_days = yearly_days

        if isinstance(data, (str, Path)):
            self.dfw = None
            self._inner: PyWeightBacktest = PyWeightBacktest.from_file(
                str(data), digits, fee_rate, n_jobs, weight_type, yearly_days
            )
            self.symbols = self._inner.symbol_dict()
            return

        # Guard only the import itself. Wrapping the whole polars branch would
        # swallow an ImportError raised while converting or loading the data and
        # silently push a polars frame into the pandas-only code below.
        try:
            import polars as pl
        except ImportError:
            pl = None

        if pl is not None and isinstance(data, (pl.DataFrame, pl.LazyFrame)):
            self.dfw = None
            arrow_data = polars_to_arrow_bytes(data)
            self._inner = PyWeightBacktest.from_arrow(
                arrow_data, digits, fee_rate, n_jobs, weight_type, yearly_days
            )
            self.symbols = self._inner.symbol_dict()
            return

        dfw = data
        if dfw["weight"].dtype != "float":
            # `assign` returns a new frame, so the caller's DataFrame is never
            # modified by the dtype normalization.
            dfw = dfw.assign(weight=dfw["weight"].astype(float))

        # The null check deliberately covers every column of the input, not
        # only the four columns that are kept afterwards.
        null_mask = dfw.isnull()
        if null_mask.to_numpy().any():
            null_rows = dfw[null_mask.any(axis=1)]
            raise ValueError(f"data 中存在空值,请先处理; 具体数据:\n{null_rows}")

        dfw = dfw[["dt", "symbol", "weight", "price"]].copy()
        dfw["weight"] = dfw["weight"].astype("float").round(digits)
        # Keep an independent copy on the instance; `dfw` itself goes to the
        # Arrow conversion helper.
        self.dfw = dfw.copy()
        self.symbols = dfw["symbol"].unique().tolist()
        arrow_data = pandas_to_arrow_bytes(dfw)
        self._inner = PyWeightBacktest.from_arrow(arrow_data, digits, fee_rate, n_jobs, weight_type, yearly_days)

    def get_top_symbols(self, n: int = 1, kind: str = "profit") -> list[str]:
        """Return the ``n`` symbols with the largest (or smallest) summed daily return.

        Args:
            n: Number of symbols to return.
            kind: ``"profit"`` sorts the per-symbol sums descending,
                ``"loss"`` ascending.

        Returns:
            Column names of ``daily_return`` (excluding ``date`` and ``total``).

        Raises:
            AssertionError: If ``kind`` is neither ``"profit"`` nor ``"loss"``.
        """
        assert kind in ["profit", "loss"], "kind 只能为 'profit' 或 'loss'"
        # `drop` without `inplace` already returns a new frame, so no defensive
        # copy of `daily_return` is needed.
        per_symbol = self.daily_return.drop(columns=["total"]).set_index("date").sum(axis=0)
        per_symbol = per_symbol.sort_values(ascending=kind != "profit")
        return per_symbol.head(n).index.tolist()

    @property
    def stats(self) -> dict:
        """Engine stats, ordered by ``STATS_FIELD_ORDER``."""
        return _reorder_stats(self._inner.stats())

    @property
    def symbol_dict(self) -> list:
        """The engine's symbol list; ``_map_symbols`` treats positions as symbol codes."""
        return self._inner.symbol_dict()

    def _map_symbols(self, df: pd.DataFrame) -> pd.DataFrame:
        """Replace numeric symbol codes in ``df["symbol"]`` with symbol names.

        The mapping is position -> name taken from ``symbol_dict``. ``df`` is
        modified in place and returned. Frames without a ``symbol`` column, or
        whose ``symbol`` column is not numeric, are returned untouched.
        """
        if "symbol" in df.columns and pd.api.types.is_numeric_dtype(df["symbol"]):
            code_to_name = dict(enumerate(self.symbol_dict))
            df["symbol"] = df["symbol"].map(code_to_name)
        return df

    @property
    def daily_return(self) -> pd.DataFrame:
        """The engine's ``daily_return`` payload as a DataFrame, symbol codes mapped to names."""
        return self._map_symbols(arrow_bytes_to_pd_df(self._inner.daily_return()))

    @property
    def dailys(self) -> pd.DataFrame:
        """The engine's ``dailys`` payload as a DataFrame, symbol codes mapped to names."""
        return self._map_symbols(arrow_bytes_to_pd_df(self._inner.dailys()))

    @property
    def alpha(self) -> pd.DataFrame:
        """The engine's ``alpha`` payload as a DataFrame.

        ``alpha_stats`` / ``bench_stats`` read its ``date``, ``"超额"`` and
        ``"基准"`` columns.
        """
        return arrow_bytes_to_pd_df(self._inner.alpha())

    def _pivot_daily_return(self, values_col: str) -> pd.DataFrame:
        """Pivot ``dailys`` into one column per symbol plus a ``total`` column.

        Args:
            values_col: Column of ``dailys`` to spread across symbols.

        Returns:
            A frame with ``date`` as a regular column, one column per symbol
            and ``total`` (row mean for ``weight_type == "ts"``, row sum for
            ``"cs"``).

        Raises:
            ValueError: If ``weight_type`` is neither ``"ts"`` nor ``"cs"``.
        """
        # `pivot_table` builds a new frame, so `dailys` needs no copy here.
        dfv = pd.pivot_table(self.dailys, index="date", columns="symbol", values=values_col)
        if self.weight_type == "ts":
            dfv["total"] = dfv.mean(axis=1)
        elif self.weight_type == "cs":
            dfv["total"] = dfv.sum(axis=1)
        else:
            raise ValueError(f"weight_type {self.weight_type} not supported")
        return dfv.reset_index(drop=False)

    def _compute_stats(self, df: pd.DataFrame, column: str) -> dict:
        """Run ``daily_performance`` on ``df[column]`` and attach the date range.

        The first and last ``date`` of ``df`` are added under the
        ``"开始日期"`` and ``"结束日期"`` keys, formatted ``YYYY-MM-DD``.
        """
        stats = daily_performance(df[column].to_numpy(), yearly_days=self.yearly_days)
        stats["开始日期"] = df["date"].min().strftime("%Y-%m-%d")
        stats["结束日期"] = df["date"].max().strftime("%Y-%m-%d")
        return stats

    @property
    def alpha_stats(self) -> dict:
        """Performance stats of the ``"超额"`` column of ``alpha``."""
        return self._compute_stats(self.alpha, "超额")

    @property
    def bench_stats(self) -> dict:
        """Performance stats of the ``"基准"`` column of ``alpha``."""
        return self._compute_stats(self.alpha, "基准")

    @property
    def long_daily_return(self) -> pd.DataFrame:
        """Per-symbol pivot of the ``long_return`` column of ``dailys``, with ``total``."""
        return self._pivot_daily_return("long_return")

    @property
    def short_daily_return(self) -> pd.DataFrame:
        """Per-symbol pivot of the ``short_return`` column of ``dailys``, with ``total``."""
        return self._pivot_daily_return("short_return")

    @property
    def long_stats(self) -> dict:
        """Engine ``long_stats``, ordered by ``STATS_FIELD_ORDER``."""
        return _reorder_stats(self._inner.long_stats())

    @property
    def short_stats(self) -> dict:
        """Engine ``short_stats``, ordered by ``STATS_FIELD_ORDER``."""
        return _reorder_stats(self._inner.short_stats())

    def segment_stats(
        self,
        sdt: str | int | pd.Timestamp | None = None,
        edt: str | int | pd.Timestamp | None = None,
        kind: str = "多空",
    ) -> dict:
        """Return the engine's stats for a date segment.

        Args:
            sdt: Segment start; converted to an integer ``YYYYMMDD`` key by
                ``_to_date_key`` (``None`` is forwarded as ``None``).
            edt: Segment end; converted the same way.
            kind: Forwarded unchanged to the engine.

        Returns:
            The engine's stats dictionary, ordered by ``STATS_FIELD_ORDER``.
        """
        sdt_int = _to_date_key(sdt)
        edt_int = _to_date_key(edt)
        return _reorder_stats(self._inner.segment_stats(sdt_int, edt_int, kind))

    @property
    def long_alpha_stats(self) -> dict:
        """Engine ``long_alpha_stats``, ordered by ``STATS_FIELD_ORDER``."""
        return _reorder_stats(self._inner.long_alpha_stats())

    @property
    def pairs(self) -> pd.DataFrame:
        """The engine's ``pairs`` payload as a DataFrame."""
        return arrow_bytes_to_pd_df(self._inner.pairs())

    def get_symbol_daily(self, symbol: str) -> pd.DataFrame:
        """Return a copy of the rows of ``dailys`` whose ``symbol`` equals ``symbol``."""
        df = self.dailys
        return df[df["symbol"] == symbol].copy()

    def get_symbol_pairs(self, symbol: str) -> pd.DataFrame:
        """Return a copy of the rows of ``pairs`` belonging to ``symbol``.

        The symbol column is ``"标的代码"`` when present, otherwise ``"symbol"``.
        """
        df = self.pairs
        symbol_col = "标的代码" if "标的代码" in df.columns else "symbol"
        return df[df[symbol_col] == symbol].copy()
def backtest(
    data: WEIGH_DATA_TYPE,
    digits: int = 2,
    fee_rate: float = 0.0002,
    n_jobs: int = 1,
    weight_type: str = "ts",
    yearly_days: int = 252,
) -> WeightBacktest:
    """Functional shortcut that builds and returns a ``WeightBacktest``.

    Every argument is forwarded unchanged to the ``WeightBacktest``
    constructor: ``data`` positionally, the rest by keyword.
    """
    options = {
        "digits": digits,
        "fee_rate": fee_rate,
        "n_jobs": n_jobs,
        "weight_type": weight_type,
        "yearly_days": yearly_days,
    }
    return WeightBacktest(data, **options)