from __future__ import annotations
import pandas as pd
import pytest
from wbt import WeightBacktest
# The 25 keys expected in ``WeightBacktest.stats``, in their listed order.
# The trailing comments are English glosses only; the string values are the
# literal dict keys and must stay exactly as written.
STATS_KEYS_25 = [
    # --- date range ---
    "开始日期",  # start date
    "结束日期",  # end date
    # --- returns and risk-adjusted returns ---
    "绝对收益",  # absolute return
    "年化收益",  # annualized return
    "夏普比率",  # Sharpe ratio
    "卡玛比率",  # Calmar ratio
    "新高占比",  # share of new highs
    "单笔盈亏比",  # per-trade profit/loss ratio
    "单笔收益",  # per-trade return
    # --- win rates by period: day, week, month, quarter, year ---
    "日胜率", "周胜率", "月胜率", "季胜率", "年胜率",
    # --- risk ---
    "最大回撤",  # max drawdown
    "年化波动率",  # annualized volatility
    "下行波动率",  # downside volatility
    "新高间隔",  # interval between new highs
    # --- trading activity ---
    "交易次数",  # number of trades
    "年化交易次数",  # annualized number of trades
    "持仓K线数",  # number of bars held
    "交易胜率",  # trade win rate
    # --- exposure / universe ---
    "多头占比",  # long share
    "空头占比",  # short share
    "品种数量",  # number of symbols
]
# The 17 performance keys expected in ``alpha_stats`` / ``bench_stats``, in
# their listed order. Trailing comments are English glosses only; the string
# values are the literal dict keys and must stay exactly as written.
PERF_KEYS_17 = [
    # --- returns and ratios ---
    "绝对收益",  # absolute return
    "年化",  # annualized
    "夏普",  # Sharpe
    "最大回撤",  # max drawdown
    "卡玛",  # Calmar
    # --- daily statistics ---
    "日胜率",  # daily win rate
    "日盈亏比",  # daily profit/loss ratio
    "日赢面",  # daily winning edge
    # --- volatility ---
    "年化波动率",  # annualized volatility
    "下行波动率",  # downside volatility
    # --- coverage and thresholds ---
    "非零覆盖",  # non-zero coverage
    "盈亏平衡点",  # break-even point
    # --- new highs and drawdowns ---
    "新高间隔",  # interval between new highs
    "新高占比",  # share of new highs
    "回撤风险",  # drawdown risk
    "回归年度回报率",  # regression-based annual return
    "长度调整平均最大回撤",  # length-adjusted average max drawdown
]
class TestWeightBacktestInit:
    """Check the basic attributes exposed by a constructed ``WeightBacktest``.

    The expected values mirror the ``wb`` fixture, which is defined outside
    this file (presumably in ``conftest.py`` — not visible here).
    """

    def test_creates_successfully(self, wb: WeightBacktest) -> None:
        expected_symbols = {"SYM_A", "SYM_B"}
        assert wb.digits == 2
        assert wb.fee_rate == pytest.approx(0.0002)  # float: compare approximately
        assert wb.weight_type == "ts"
        assert set(wb.symbols) == expected_symbols  # order-insensitive
class TestStats:
    """Tests for ``WeightBacktest.stats``, the 25-entry summary dict."""

    def test_stats_keys(self, wb: WeightBacktest) -> None:
        """``stats`` must contain exactly the keys listed in ``STATS_KEYS_25``."""
        stats = wb.stats
        assert isinstance(stats, dict)
        # Report which keys are missing or unexpected instead of failing on a
        # bare length mismatch that hides the culprit.
        missing = [key for key in STATS_KEYS_25 if key not in stats]
        unexpected = [key for key in stats if key not in STATS_KEYS_25]
        assert not missing, f"missing keys: {missing}"
        assert not unexpected, f"unexpected keys: {unexpected}"
        assert len(stats) == 25

    def test_stats_date_format(self, wb: WeightBacktest) -> None:
        """Start and end dates are 10-character strings in chronological order."""
        stats = wb.stats
        for key in ("开始日期", "结束日期"):
            value = stats[key]
            assert isinstance(value, str), f"{key}: expected str, got {type(value)!r}"
            assert len(value) == 10, f"{key}: expected 10-char date, got {value!r}"
        # Parse instead of comparing raw strings so the check is format-agnostic.
        assert pd.Timestamp(stats["开始日期"]) <= pd.Timestamp(stats["结束日期"])

    def test_stats_values_consistency(self, wb: WeightBacktest) -> None:
        """Ratio-like stats lie in [0, 1]; drawdown and volatility are non-negative.

        Key presence is covered by ``test_stats_keys``, so it is not re-checked here.
        """
        stats = wb.stats
        assert stats["品种数量"] == 2  # the fixture uses two symbols
        for key in ("多头占比", "空头占比", "日胜率", "交易胜率"):
            assert 0 <= stats[key] <= 1.0, f"{key} out of [0, 1]: {stats[key]!r}"
        for key in ("最大回撤", "年化波动率"):
            assert stats[key] >= 0, f"{key} is negative: {stats[key]!r}"
class TestSymbolDict:
    """Tests for ``WeightBacktest.symbol_dict``."""

    def test_symbol_dict(self, wb: WeightBacktest) -> None:
        entries = wb.symbol_dict
        assert isinstance(entries, list)
        # presumably one entry per fixture symbol (SYM_A, SYM_B) — verify
        assert len(entries) == 2
class TestDailyReturn:
    """Tests for ``WeightBacktest.daily_return``."""

    def test_structure(self, wb: WeightBacktest) -> None:
        """The frame has ``date`` and ``total`` columns and at least one row."""
        frame = wb.daily_return
        assert isinstance(frame, pd.DataFrame)
        for column in ("date", "total"):
            assert column in frame.columns
        row_count = len(frame)
        assert row_count > 0
class TestDailys:
    """Tests for ``WeightBacktest.dailys``, the per-symbol daily detail frame."""

    def test_columns(self, wb: WeightBacktest) -> None:
        """Every expected column, including the long/short breakdowns, is present."""
        frame = wb.dailys
        assert isinstance(frame, pd.DataFrame)
        required = (
            # aggregate columns
            "symbol", "date", "n1b", "edge", "return", "cost", "turnover",
            # long/short breakdowns
            "long_edge", "short_edge",
            "long_cost", "short_cost",
            "long_turnover", "short_turnover",
            "long_return", "short_return",
        )
        present = frame.columns
        for name in required:
            assert name in present, f"missing: {name}"

    def test_return_equals_edge_minus_cost(self, wb: WeightBacktest) -> None:
        """``return`` equals ``edge - cost`` up to an absolute tolerance of 1e-8."""
        frame = wb.dailys
        pd.testing.assert_series_equal(
            frame["return"],
            frame["edge"] - frame["cost"],
            check_names=False,
            atol=1e-8,
        )

    def test_long_short_edge_consistency(self, wb: WeightBacktest) -> None:
        """``edge`` equals ``long_edge + short_edge`` up to an absolute tolerance of 1e-8."""
        frame = wb.dailys
        pd.testing.assert_series_equal(
            frame["edge"],
            frame["long_edge"] + frame["short_edge"],
            check_names=False,
            atol=1e-8,
        )
class TestAlpha:
    """Tests for ``WeightBacktest.alpha``."""

    def test_structure(self, wb: WeightBacktest) -> None:
        """Columns are exactly ``date``, ``超额`` (excess), ``策略`` (strategy), ``基准`` (benchmark), in order."""
        frame = wb.alpha
        assert isinstance(frame, pd.DataFrame)
        expected_columns = ["date", "超额", "策略", "基准"]
        assert list(frame.columns) == expected_columns

    def test_alpha_equals_strategy_minus_benchmark(self, wb: WeightBacktest) -> None:
        """``超额`` equals ``策略 - 基准`` up to an absolute tolerance of 1e-10."""
        frame = wb.alpha
        pd.testing.assert_series_equal(
            frame["超额"],
            frame["策略"] - frame["基准"],
            check_names=False,
            atol=1e-10,
        )
class TestPairs:
    """Tests for ``WeightBacktest.pairs``."""

    def test_structure(self, wb: WeightBacktest) -> None:
        """Returns a DataFrame; when it has rows, ``symbol`` and ``交易方向`` (trade direction) exist."""
        pairs = wb.pairs
        assert isinstance(pairs, pd.DataFrame)
        # Column checks are only meaningful when the frame has rows.
        if len(pairs) == 0:
            return
        for column in ("symbol", "交易方向"):
            assert column in pairs.columns
class TestAlphaAndBenchStats:
    """Tests for ``WeightBacktest.alpha_stats`` and ``WeightBacktest.bench_stats``.

    Both must be dicts containing every key in ``PERF_KEYS_17``. Failures list
    all missing keys at once, matching the per-key messages used for ``stats``.
    """

    def test_alpha_stats(self, wb: WeightBacktest) -> None:
        """``alpha_stats`` has the date range plus all 17 performance keys."""
        stats = wb.alpha_stats
        assert isinstance(stats, dict)
        required = ("开始日期", "结束日期", *PERF_KEYS_17)
        missing = [key for key in required if key not in stats]
        assert not missing, f"missing keys: {missing}"

    def test_bench_stats(self, wb: WeightBacktest) -> None:
        """``bench_stats`` has all 17 performance keys."""
        stats = wb.bench_stats
        assert isinstance(stats, dict)
        missing = [key for key in PERF_KEYS_17 if key not in stats]
        assert not missing, f"missing keys: {missing}"
class TestLongShortReturns:
    """Tests for the long-only and short-only views of the backtest."""

    def test_long_daily_return(self, wb: WeightBacktest) -> None:
        """``long_daily_return`` is a DataFrame with a ``total`` column."""
        frame = wb.long_daily_return
        assert isinstance(frame, pd.DataFrame)
        assert "total" in frame.columns

    def test_short_daily_return(self, wb: WeightBacktest) -> None:
        """``short_daily_return`` is a DataFrame with a ``total`` column."""
        frame = wb.short_daily_return
        assert isinstance(frame, pd.DataFrame)
        assert "total" in frame.columns

    def test_long_stats(self, wb: WeightBacktest) -> None:
        """``long_stats`` has annualized return, Sharpe ratio and trade count."""
        summary = wb.long_stats
        assert isinstance(summary, dict)
        for key in ("年化收益", "夏普比率", "交易次数"):
            assert key in summary

    def test_short_stats(self, wb: WeightBacktest) -> None:
        """``short_stats`` has annualized return and Sharpe ratio."""
        summary = wb.short_stats
        assert isinstance(summary, dict)
        for key in ("年化收益", "夏普比率"):
            assert key in summary
class TestSegmentStats:
    """Tests for ``WeightBacktest.segment_stats`` across its ``kind`` options."""

    def test_segment_stats_default(self, wb: WeightBacktest) -> None:
        """Called with no arguments, it returns annualized return and trade count."""
        result = wb.segment_stats()
        assert isinstance(result, dict)
        for key in ("年化收益", "交易次数"):
            assert key in result

    def test_segment_stats_long(self, wb: WeightBacktest) -> None:
        """``kind="多头"`` (long) returns a dict with annualized return."""
        result = wb.segment_stats(kind="多头")
        assert isinstance(result, dict)
        assert "年化收益" in result

    def test_segment_stats_short(self, wb: WeightBacktest) -> None:
        """``kind="空头"`` (short) returns a dict with annualized return."""
        result = wb.segment_stats(kind="空头")
        assert isinstance(result, dict)
        assert "年化收益" in result
class TestLongAlphaStats:
    """Tests for ``WeightBacktest.long_alpha_stats``."""

    def test_long_alpha_stats(self, wb: WeightBacktest) -> None:
        """Returns a dict with annualized return and Sharpe ratio."""
        summary = wb.long_alpha_stats
        assert isinstance(summary, dict)
        for key in ("年化收益", "夏普比率"):
            assert key in summary
class TestSymbolMethods:
    """Tests for the per-symbol query helpers on ``WeightBacktest``."""

    def test_get_top_symbols_profit(self, wb: WeightBacktest) -> None:
        """Top-1 by profit is a list with at most one entry."""
        result = wb.get_top_symbols(n=1, kind="profit")
        assert isinstance(result, list)
        assert len(result) <= 1

    def test_get_top_symbols_loss(self, wb: WeightBacktest) -> None:
        """Top-1 by loss is a list with at most one entry."""
        result = wb.get_top_symbols(n=1, kind="loss")
        assert isinstance(result, list)
        assert len(result) <= 1

    def test_get_top_symbols_n_exceeds(self, wb: WeightBacktest) -> None:
        """When ``n`` exceeds the symbol count, the result is capped, not padded."""
        result = wb.get_top_symbols(n=10, kind="profit")
        assert isinstance(result, list)
        assert len(result) <= len(wb.symbols)

    def test_get_symbol_daily(self, wb: WeightBacktest) -> None:
        """The per-symbol daily frame is non-empty and contains only that symbol."""
        df = wb.get_symbol_daily("SYM_A")
        assert isinstance(df, pd.DataFrame)
        # An all() check over an empty column is vacuously true, so require rows first.
        assert len(df) > 0, "expected daily rows for SYM_A"
        assert (df["symbol"] == "SYM_A").all()

    def test_get_symbol_pairs(self, wb: WeightBacktest) -> None:
        """The per-symbol pairs frame contains only that symbol, when non-empty."""
        df = wb.get_symbol_pairs("SYM_A")
        assert isinstance(df, pd.DataFrame)
        # The fixture may yield no trade pairs; only check filtering when rows exist.
        if len(df) > 0:
            assert (df["symbol"] == "SYM_A").all()