from __future__ import annotations
import math
from datetime import datetime, timedelta
import numpy as np
import pandas as pd
import pytest
from wbt import WeightBacktest, daily_performance
def _make_dfw(n_days: int, symbols: list[str], weight_fn, price_fn) -> pd.DataFrame:
    """Build a long-format weight/price frame with one row per (day, symbol).

    Timestamps start at 2024-01-02 09:30:00 and advance by one calendar day
    per step; they are stored as ``"%Y-%m-%d %H:%M:%S"`` strings.

    Args:
        n_days: Number of consecutive calendar days to generate.
        symbols: Symbol names; every day emits one row per symbol, in order.
        weight_fn: Callable ``(day_index, symbol)`` producing the ``weight`` cell.
        price_fn: Callable ``(day_index, symbol)`` producing the ``price`` cell.

    Returns:
        A DataFrame built from the generated records. When at least one row is
        produced its columns are ``dt``, ``symbol``, ``weight`` and ``price``,
        ordered day-major and then by symbol.
    """
    start = datetime(2024, 1, 2, 9, 30, 0)
    stamps = [
        (start + timedelta(days=offset)).strftime("%Y-%m-%d %H:%M:%S")
        for offset in range(n_days)
    ]
    records = [
        {
            "dt": stamp,
            "symbol": name,
            "weight": weight_fn(day, name),
            "price": price_fn(day, name),
        }
        for day, stamp in enumerate(stamps)
        for name in symbols
    ]
    return pd.DataFrame(records)
class TestDailyPerformanceEdgeCases:
    """Edge-case expectations for ``daily_performance`` on tiny or degenerate return series.

    Every test calls ``daily_performance`` with ``yearly_days=252`` and checks
    entries of the returned mapping by their (Chinese) metric keys.
    """

    def test_single_value_positive(self) -> None:
        """A single positive return is expected to give a zero total and annualised return."""
        dp = daily_performance(np.array([0.05]), yearly_days=252)
        assert dp["绝对收益"] == 0.0
        assert dp["年化"] == 0.0

    def test_single_value_negative(self) -> None:
        """A single negative return is expected to give a zero total return."""
        dp = daily_performance(np.array([-0.03]), yearly_days=252)
        assert dp["绝对收益"] == 0.0

    def test_constant_positive_returns(self) -> None:
        """100 identical positive returns are expected to give zero total return and Sharpe."""
        dp = daily_performance(np.array([0.001] * 100), yearly_days=252)
        assert dp["绝对收益"] == 0.0
        assert dp["夏普"] == 0.0

    def test_constant_negative_returns(self) -> None:
        """100 identical negative returns are expected to give a zero total return."""
        dp = daily_performance(np.array([-0.001] * 100), yearly_days=252)
        assert dp["绝对收益"] == 0.0

    def test_cum_return_near_zero(self) -> None:
        """Two offsetting returns (+1%, -1%) are expected to give a zero total return."""
        dp = daily_performance(np.array([0.01, -0.01]), yearly_days=252)
        assert dp["绝对收益"] == 0.0

    def test_two_values_with_variance(self) -> None:
        """Two distinct positive returns produce positive headline metrics and no drawdown."""
        dp = daily_performance(np.array([0.02, 0.01]), yearly_days=252)
        assert dp["绝对收益"] == pytest.approx(0.03, abs=0.001)
        assert dp["年化"] > 0
        assert dp["夏普"] > 0
        # One assertion per line: each metric is checked (and reported) independently.
        assert dp["最大回撤"] == 0.0
        assert dp["日胜率"] == 1.0
        assert dp["下行波动率"] == 0.0

    def test_all_positive_no_drawdown(self) -> None:
        """Strictly positive returns: no drawdown, always at a new high, zero new-high gap."""
        dp = daily_performance(np.array([0.01, 0.02, 0.03, 0.04]), yearly_days=252)
        assert dp["最大回撤"] == 0.0
        assert dp["新高占比"] == 1.0
        assert dp["新高间隔"] == 0.0

    def test_all_negative(self) -> None:
        """Strictly negative returns: negative return/Sharpe, positive drawdown and downside vol."""
        dp = daily_performance(np.array([-0.01, -0.02, -0.03]), yearly_days=252)
        assert dp["绝对收益"] < 0
        assert dp["年化"] < 0
        assert dp["夏普"] < 0
        assert dp["最大回撤"] > 0
        assert dp["日胜率"] == 0.0
        assert dp["下行波动率"] > 0

    def test_sharpe_upper_cap(self) -> None:
        """Large, nearly constant positive returns are expected to report a Sharpe of exactly 10.0."""
        dp = daily_performance(np.array([0.1, 0.1001]), yearly_days=252)
        assert dp["夏普"] == 10.0

    def test_sharpe_lower_cap(self) -> None:
        """Large, nearly constant negative returns are expected to report a Sharpe of exactly -5.0."""
        dp = daily_performance(np.array([-0.1, -0.1001]), yearly_days=252)
        assert dp["夏普"] == -5.0

    def test_calmar_cap_when_no_drawdown(self) -> None:
        """With zero drawdown the Calmar ratio is expected to be reported as 10.0."""
        dp = daily_performance(np.array([0.01, 0.02, 0.03, 0.005]), yearly_days=252)
        assert dp["最大回撤"] == 0.0
        assert dp["卡玛"] == 10.0
class TestMinimalData:
    """Backtests over the smallest inputs exercised by this suite (two and three bars)."""

    def test_two_bars_one_symbol(self) -> None:
        """Two daily bars of one symbol at a constant 0.5 weight: total return is expected to be 0.0."""
        frame = _make_dfw(2, ["A"], lambda d, s: 0.5, lambda d, s: 100.0 + d)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        total_return = backtest.stats["绝对收益"]
        assert total_return == 0.0

    def test_three_bars_one_symbol(self) -> None:
        """Three bars with an up-then-down price path give a non-zero return and a float Sharpe."""
        price_path = [100.0, 102.0, 101.0]
        frame = _make_dfw(3, ["A"], lambda d, s: 0.5, lambda d, s: price_path[d])
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        summary = backtest.stats
        assert summary["绝对收益"] != 0.0
        assert isinstance(summary["夏普比率"], float)
class TestPureLong:
    """Portfolios that only ever hold positive weights."""

    def test_short_stats_zero(self) -> None:
        """With every weight at +0.3 the short side shows no exposure, return or trades."""

        def price_of(day, sym):
            # "A" climbs one unit per day, every other symbol half a unit.
            return 100.0 + day * (1 if sym == "A" else 0.5)

        frame = _make_dfw(10, ["A", "B"], lambda d, s: 0.3, price_of)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, yearly_days=252)
        assert backtest.stats["空头占比"] == 0.0
        assert backtest.short_stats["绝对收益"] == 0.0
        assert backtest.short_stats["交易次数"] == 0

    def test_long_stats_nonzero(self) -> None:
        """A constant +0.5 weight on a rising price gives a non-zero long return unless no trade is counted."""
        frame = _make_dfw(10, ["A"], lambda d, s: 0.5, lambda d, s: 100.0 + d * 0.5)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        long_side = backtest.long_stats
        assert long_side["绝对收益"] != 0.0 or long_side["交易次数"] == 0
class TestPureShort:
    """Portfolios that only ever hold negative weights."""

    def test_long_stats_zero(self) -> None:
        """With every weight at -0.3 the long side shows no exposure, return or trades."""

        def price_of(day, sym):
            # "A" climbs one unit per day, every other symbol half a unit.
            return 100.0 + day * (1 if sym == "A" else 0.5)

        frame = _make_dfw(10, ["A", "B"], lambda d, s: -0.3, price_of)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, yearly_days=252)
        assert backtest.stats["多头占比"] == 0.0
        assert backtest.long_stats["绝对收益"] == 0.0
        assert backtest.long_stats["交易次数"] == 0

    def test_short_stats_nonzero(self) -> None:
        """A constant -0.5 weight on a rising price gives a non-zero short return unless no trade is counted."""
        frame = _make_dfw(10, ["A"], lambda d, s: -0.5, lambda d, s: 100.0 + d * 0.5)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        short_side = backtest.short_stats
        assert short_side["绝对收益"] != 0.0 or short_side["交易次数"] == 0
class TestZeroWeightsAllBars:
    """A portfolio that never takes a position."""

    def test_all_zero_weights(self) -> None:
        """Zero weight on every bar: no return, no trades, no long or short exposure."""
        frame = _make_dfw(10, ["A"], lambda d, s: 0.0, lambda d, s: 100.0 + d)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        # Checked in this order, reading ``stats`` once per metric.
        expectations = (
            ("绝对收益", 0.0),
            ("交易次数", 0),
            ("多头占比", 0.0),
            ("空头占比", 0.0),
        )
        for metric, wanted in expectations:
            assert backtest.stats[metric] == wanted
class TestSegmentStatsEdgeCases:
    """``segment_stats`` over short ranges, explicit kinds and the full sample."""

    @pytest.fixture
    def bt(self) -> WeightBacktest:
        """Two-symbol, 20-day backtest whose weights flip to -0.2 every third day (offset by one day for non-"A")."""

        def weight_of(day, sym):
            shift = 0 if sym == "A" else 1
            return -0.2 if (day + shift) % 3 == 0 else 0.3

        def price_of(day, sym):
            # "A" drifts up from 100; any other symbol drifts down from 150.
            slope, offset = (0.5, 0) if sym == "A" else (-0.3, 50)
            return 100.0 + day * slope + offset

        frame = _make_dfw(20, ["A", "B"], weight_of, price_of)
        return WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, yearly_days=252)

    def test_two_day_range(self, bt: WeightBacktest) -> None:
        """A two-day window still yields numeric return and trade-count entries."""
        numeric = (int, float)
        window = bt.segment_stats(sdt=20240105, edt=20240106)
        assert isinstance(window["绝对收益"], numeric)
        assert isinstance(window["交易次数"], numeric)

    def test_long_in_pure_short_range(self, bt: WeightBacktest) -> None:
        """Requesting the long-only kind over 20240102-20240121 yields a numeric return."""
        window = bt.segment_stats(sdt=20240102, edt=20240121, kind="多头")
        assert isinstance(window["绝对收益"], (int, float))

    def test_segment_stats_all_three_kinds_sum(self, bt: WeightBacktest) -> None:
        """Long plus short segment returns match the default-kind return within 0.001."""
        overall = bt.segment_stats()
        long_part = bt.segment_stats(kind="多头")
        short_part = bt.segment_stats(kind="空头")
        long_plus_short = long_part["绝对收益"] + short_part["绝对收益"]
        assert overall["绝对收益"] == pytest.approx(long_plus_short, abs=0.001)
class TestLongAlphaStatsEdgeCases:
    """``long_alpha_stats`` on degenerate inputs: no long leg, flat prices, flat weights."""

    def test_pure_short_long_vol_zero(self) -> None:
        """A short-only book is expected to report zero alpha return, Sharpe and drawdown."""
        frame = _make_dfw(10, ["A"], lambda d, s: -0.5, lambda d, s: 100.0 + d * 0.5)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        alpha_stats = backtest.long_alpha_stats
        for metric in ("绝对收益", "夏普比率", "最大回撤"):
            assert alpha_stats[metric] == 0.0

    def test_constant_prices_bench_vol_zero(self) -> None:
        """With a flat price of 100.0 the alpha return is expected to be zero."""
        frame = _make_dfw(10, ["A"], lambda d, s: 0.5, lambda d, s: 100.0)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        alpha_stats = backtest.long_alpha_stats
        assert alpha_stats["绝对收益"] == 0.0

    def test_zero_weights_zero_alpha(self) -> None:
        """With zero weights throughout the alpha return is expected to be zero."""
        frame = _make_dfw(10, ["A"], lambda d, s: 0.0, lambda d, s: 100.0 + d * 0.5)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        alpha_stats = backtest.long_alpha_stats
        assert alpha_stats["绝对收益"] == 0.0

    def test_mixed_weights_valid_alpha(self) -> None:
        """Cycling long/short weights give float alpha metrics and a yearly win rate of 0.0."""
        weight_cycle = [0.3, 0.3, -0.2, 0.5, -0.1]

        def price_of(day, sym):
            slope = 0.5 if sym == "A" else -0.3
            return 100.0 + day * slope + math.sin(day) * 2

        frame = _make_dfw(15, ["A", "B"], lambda d, s: weight_cycle[d % 5], price_of)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, yearly_days=252)
        alpha_stats = backtest.long_alpha_stats
        assert isinstance(alpha_stats["绝对收益"], float)
        assert isinstance(alpha_stats["夏普比率"], float)
        assert alpha_stats["年胜率"] == 0.0

    def test_alpha_keys_complete_in_zero_vol_case(self) -> None:
        """Even for a short-only book every expected metric key must be present."""
        frame = _make_dfw(10, ["A"], lambda d, s: -0.5, lambda d, s: 100.0 + d * 0.5)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        alpha_stats = backtest.long_alpha_stats
        required = (
            "绝对收益",
            "年化收益",
            "夏普比率",
            "卡玛比率",
            "新高占比",
            "日胜率",
            "周胜率",
            "月胜率",
            "季胜率",
            "年胜率",
            "最大回撤",
            "年化波动率",
            "下行波动率",
            "新高间隔",
        )
        for key in required:
            assert key in alpha_stats, f"Missing key in zero-vol case: {key}"
class TestSingleSymbolMetrics:
    """Metric completeness and long/short decomposition for a one-symbol portfolio."""

    def test_single_symbol_stats_complete(self) -> None:
        """All headline metric keys exist and the symbol count is 1."""
        weight_cycle = [0.3, 0.3, -0.2, 0.0, 0.5]

        def price_of(day, sym):
            return 100.0 + day * 0.5 + math.sin(day) * 2

        frame = _make_dfw(20, ["ONLY"], lambda d, s: weight_cycle[d % 5], price_of)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, yearly_days=252)
        summary = backtest.stats
        required = (
            "绝对收益",
            "年化收益",
            "夏普比率",
            "卡玛比率",
            "最大回撤",
            "日胜率",
            "周胜率",
            "月胜率",
            "季胜率",
            "年胜率",
            "交易次数",
            "年化交易次数",
            "多头占比",
            "空头占比",
            "品种数量",
        )
        for key in required:
            assert key in summary, f"Missing key: {key}"
        assert summary["品种数量"] == 1

    def test_single_symbol_long_short_sum(self) -> None:
        """Daily long and short returns add up to the daily total return (atol 1e-8)."""
        weight_cycle = [0.3, -0.2, 0.5, 0.0, -0.4]
        frame = _make_dfw(15, ["X"], lambda d, s: weight_cycle[d % 5], lambda d, s: 100.0 + d * 0.3)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, yearly_days=252)
        daily = backtest.dailys
        long_plus_short = daily["long_return"] + daily["short_return"]
        pd.testing.assert_series_equal(daily["return"], long_plus_short, check_names=False, atol=1e-8)
class TestCSMode:
    """Behaviour with ``weight_type="cs"``."""

    def test_cs_long_daily_return(self) -> None:
        """In "cs" mode the ``total`` column equals the row-wise sum of the per-symbol columns."""

        def weight_of(day, sym):
            return 0.3 if sym == "A" else -0.2

        def price_of(day, sym):
            slope = 0.5 if sym == "A" else -0.3
            # The boolean acts as 0/1, lifting every non-"A" symbol by 50.
            return 100.0 + day * slope + 50 * (sym != "A")

        frame = _make_dfw(10, ["A", "B"], weight_of, price_of)
        backtest = WeightBacktest(
            frame, digits=2, fee_rate=0.0002, n_jobs=1, weight_type="cs", yearly_days=252
        )
        returns = backtest.daily_return
        symbol_columns = [col for col in returns.columns if col not in ("date", "total")]
        row_sums = returns[symbol_columns].sum(axis=1)
        np.testing.assert_allclose(returns["total"].values, row_sums.values, atol=1e-6)
class TestInvalidWeightType:
    """Handling of an unrecognised ``weight_type`` value."""

    def test_invalid_weight_type_defaults_to_ts(self) -> None:
        """An unknown ``weight_type`` is expected to give the same total return as "ts"."""
        frame = _make_dfw(10, ["A", "B"], lambda d, s: 0.3, lambda d, s: 100.0 + d)
        # Each backtest receives its own copy of the input frame.
        reference = WeightBacktest(frame.copy(), weight_type="ts")
        fallback = WeightBacktest(frame.copy(), weight_type="INVALID")
        assert reference.stats["绝对收益"] == fallback.stats["绝对收益"]
class TestManySymbols:
    """Aggregation across several symbols with ``weight_type="ts"``."""

    def test_five_symbols(self) -> None:
        """Five symbols are counted, each gets a column, and ``total`` is their row-wise mean."""
        names = [f"SYM_{idx}" for idx in range(5)]

        def weight_of(day, sym):
            # Even-numbered symbols are long, odd-numbered ones short.
            return 0.3 if int(sym[-1]) % 2 == 0 else -0.2

        def price_of(day, sym):
            return 100.0 + day * (int(sym[-1]) + 1) * 0.1

        frame = _make_dfw(15, names, weight_of, price_of)
        backtest = WeightBacktest(
            frame, digits=2, fee_rate=0.0002, n_jobs=1, weight_type="ts", yearly_days=252
        )
        assert backtest.stats["品种数量"] == 5
        returns = backtest.daily_return
        symbol_columns = [col for col in returns.columns if col not in ("date", "total")]
        assert len(symbol_columns) == 5
        row_means = returns[symbol_columns].mean(axis=1)
        np.testing.assert_allclose(returns["total"].values, row_means.values, atol=1e-6)