wbt 0.2.1

Weight-based backtesting engine for quantitative trading
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
"""Edge case tests for wbt metrics system.

Covers special scenarios that trigger early-return paths or boundary conditions
in daily_performance, segment_stats, long_stats/short_stats, and long_alpha_stats.
"""

from __future__ import annotations

import math
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import pytest

from wbt import WeightBacktest, daily_performance

# ============================================================================
# Helper
# ============================================================================


def _make_dfw(n_days: int, symbols: list[str], weight_fn, price_fn) -> pd.DataFrame:
    """Build deterministic test DataFrame."""
    base = datetime(2024, 1, 2, 9, 30, 0)
    rows = []
    for d in range(n_days):
        dt_str = (base + timedelta(days=d)).strftime("%Y-%m-%d %H:%M:%S")
        for sym in symbols:
            rows.append(
                {
                    "dt": dt_str,
                    "symbol": sym,
                    "weight": weight_fn(d, sym),
                    "price": price_fn(d, sym),
                }
            )
    return pd.DataFrame(rows)


# ============================================================================
# 1. daily_performance edge cases
# ============================================================================


class TestDailyPerformanceEdgeCases:
    """Boundary inputs for the standalone ``daily_performance`` function.

    Degenerate series (zero standard deviation, or a cumulative return of
    zero) are expected to take the early-return path that reports zeros. The
    remaining cases pin sign conventions and the Sharpe/Calmar caps.
    """

    def test_single_value_positive(self) -> None:
        """A lone positive return has std=0, so a zeroed result is expected."""
        result = daily_performance(np.array([0.05]), yearly_days=252)
        assert result["绝对收益"] == 0.0
        assert result["年化"] == 0.0

    def test_single_value_negative(self) -> None:
        """A lone negative return has std=0, so a zeroed result is expected."""
        result = daily_performance(np.array([-0.03]), yearly_days=252)
        assert result["绝对收益"] == 0.0

    def test_constant_positive_returns(self) -> None:
        """A flat positive series has std=0, so a zeroed result is expected."""
        result = daily_performance(np.full(100, 0.001), yearly_days=252)
        assert result["绝对收益"] == 0.0
        assert result["夏普"] == 0.0

    def test_constant_negative_returns(self) -> None:
        """A flat negative series has std=0, so a zeroed result is expected."""
        result = daily_performance(np.full(100, -0.001), yearly_days=252)
        assert result["绝对收益"] == 0.0

    def test_cum_return_near_zero(self) -> None:
        """Offsetting returns sum to ~0, so a zeroed result is expected."""
        result = daily_performance(np.array([0.01, -0.01]), yearly_days=252)
        assert result["绝对收益"] == 0.0

    def test_two_values_with_variance(self) -> None:
        """Two distinct positive returns: std > 0 and cum > 0 give real metrics."""
        result = daily_performance(np.array([0.02, 0.01]), yearly_days=252)
        assert result["绝对收益"] == pytest.approx(0.03, abs=0.001)
        assert result["年化"] > 0
        assert result["夏普"] > 0
        # Every day is a gain: no drawdown, full win rate, no downside volatility.
        assert result["最大回撤"] == 0.0
        assert result["日胜率"] == 1.0
        assert result["下行波动率"] == 0.0

    def test_all_positive_no_drawdown(self) -> None:
        """A strictly rising equity curve never draws down."""
        result = daily_performance(np.array([0.01, 0.02, 0.03, 0.04]), yearly_days=252)
        assert result["最大回撤"] == 0.0
        assert result["新高占比"] == 1.0
        # 新高间隔 is the longest run of days spent strictly under water. A
        # series that sets a new high every day is never under water, so 0.
        assert result["新高间隔"] == 0.0

    def test_all_negative(self) -> None:
        """Only losses: negative headline metrics, a drawdown, and zero win rate."""
        result = daily_performance(np.array([-0.01, -0.02, -0.03]), yearly_days=252)
        for key in ("绝对收益", "年化", "夏普"):
            assert result[key] < 0
        assert result["最大回撤"] > 0
        assert result["日胜率"] == 0.0
        assert result["下行波动率"] > 0

    def test_sharpe_upper_cap(self) -> None:
        """A huge positive Sharpe (large mean, tiny spread) is clamped to 10.0."""
        result = daily_performance(np.array([0.1, 0.1001]), yearly_days=252)
        assert result["夏普"] == 10.0

    def test_sharpe_lower_cap(self) -> None:
        """A huge negative Sharpe is clamped to -5.0."""
        result = daily_performance(np.array([-0.1, -0.1001]), yearly_days=252)
        assert result["夏普"] == -5.0

    def test_calmar_cap_when_no_drawdown(self) -> None:
        """With no drawdown to divide by, Calmar is clamped to 10.0."""
        result = daily_performance(np.array([0.01, 0.02, 0.03, 0.005]), yearly_days=252)
        assert result["最大回撤"] == 0.0
        assert result["卡玛"] == 10.0


# ============================================================================
# 2. WeightBacktest with extreme data
# ============================================================================


class TestMinimalData:
    """Smallest inputs the engine should still accept (2-3 bars)."""

    def test_two_bars_one_symbol(self) -> None:
        """Two bars are expected to leave a single daily return (std=0), so stats are zeroed."""
        frame = _make_dfw(2, ["A"], lambda d, s: 0.5, lambda d, s: 100.0 + d)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        assert backtest.stats["绝对收益"] == 0.0

    def test_three_bars_one_symbol(self) -> None:
        """Three bars give two differing returns (std > 0), so real stats are produced."""
        closes = [100.0, 102.0, 101.0]
        frame = _make_dfw(3, ["A"], lambda d, s: 0.5, lambda d, s: closes[d])
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        stats = backtest.stats
        # Price moves +2% and then about -0.98% while held at weight 0.5. The two
        # returns differ, so the non-degenerate code path should run.
        assert stats["绝对收益"] != 0.0
        assert isinstance(stats["夏普比率"], float)


class TestPureLong:
    """Strictly positive weights: the short side should stay empty."""

    def test_short_stats_zero(self) -> None:
        """With only long exposure, short_stats reports no return and no trades."""

        def price(d, s):
            return 100.0 + d * (1 if s == "A" else 0.5)

        frame = _make_dfw(10, ["A", "B"], lambda d, s: 0.3, price)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, yearly_days=252)
        assert backtest.stats["空头占比"] == 0.0
        short_side = backtest.short_stats
        assert short_side["绝对收益"] == 0.0
        assert short_side["交易次数"] == 0

    def test_long_stats_nonzero(self) -> None:
        """With only long exposure on a rising price, long_stats carries the return."""
        frame = _make_dfw(10, ["A"], lambda d, s: 0.5, lambda d, s: 100.0 + d * 0.5)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        long_side = backtest.long_stats
        # Expect a non-zero long return; zero is tolerated only when the engine
        # recorded no trades at all.
        assert long_side["绝对收益"] != 0.0 or long_side["交易次数"] == 0


class TestPureShort:
    """Strictly negative weights: the long side should stay empty."""

    def test_long_stats_zero(self) -> None:
        """With only short exposure, long_stats reports no return and no trades."""

        def price(d, s):
            return 100.0 + d * (1 if s == "A" else 0.5)

        frame = _make_dfw(10, ["A", "B"], lambda d, s: -0.3, price)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, yearly_days=252)
        assert backtest.stats["多头占比"] == 0.0
        long_side = backtest.long_stats
        assert long_side["绝对收益"] == 0.0
        assert long_side["交易次数"] == 0

    def test_short_stats_nonzero(self) -> None:
        """With only short exposure, short_stats carries the (non-zero) return."""
        frame = _make_dfw(10, ["A"], lambda d, s: -0.5, lambda d, s: 100.0 + d * 0.5)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)
        short_side = backtest.short_stats
        # Shorting a rising price should lose money. Zero is tolerated only when
        # the engine recorded no trades at all.
        assert short_side["绝对收益"] != 0.0 or short_side["交易次数"] == 0


class TestZeroWeightsAllBars:
    """Every weight is zero, so no position is ever opened."""

    def test_all_zero_weights(self) -> None:
        """Flat book: zero return, no trades, and no long or short share."""
        frame = _make_dfw(10, ["A"], lambda d, s: 0.0, lambda d, s: 100.0 + d)
        stats = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252).stats
        assert stats["绝对收益"] == 0.0
        assert stats["交易次数"] == 0
        assert stats["多头占比"] == 0.0
        assert stats["空头占比"] == 0.0


# ============================================================================
# 3. segment_stats edge cases (beyond test_metrics_correctness.py)
# ============================================================================


class TestSegmentStatsEdgeCases:
    """Further boundary cases for ``WeightBacktest.segment_stats``."""

    @pytest.fixture
    def bt(self) -> WeightBacktest:
        """Two symbols over 20 days whose weights alternate between 0.3 and -0.2."""

        def weight(d, s):
            # Symbol B is phase-shifted by one day so the two books differ.
            shift = 0 if s == "A" else 1
            return -0.2 if (d + shift) % 3 == 0 else 0.3

        def price(d, s):
            if s == "A":
                return 100.0 + d * 0.5 + 0
            return 100.0 + d * -0.3 + 50

        frame = _make_dfw(20, ["A", "B"], weight, price)
        return WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, yearly_days=252)

    def test_two_day_range(self, bt: WeightBacktest) -> None:
        """A two-day window still yields numeric return and trade-count fields."""
        seg = bt.segment_stats(sdt=20240105, edt=20240106)
        # Only two days fall in the window, so just check the output types.
        for key in ("绝对收益", "交易次数"):
            assert isinstance(seg[key], (int, float))

    def test_long_in_pure_short_range(self, bt: WeightBacktest) -> None:
        """kind='多头' over the full 20-day window (mixed long/short days) yields a numeric return."""
        seg = bt.segment_stats(sdt=20240102, edt=20240121, kind="多头")
        assert isinstance(seg["绝对收益"], (int, float))

    def test_segment_stats_all_three_kinds_sum(self, bt: WeightBacktest) -> None:
        """Over the full range, the long and short absolute returns add up to the combined one."""
        total = bt.segment_stats()["绝对收益"]
        long_part = bt.segment_stats(kind="多头")["绝对收益"]
        short_part = bt.segment_stats(kind="空头")["绝对收益"]
        assert total == pytest.approx(long_part + short_part, abs=0.001)


# ============================================================================
# 4. long_alpha_stats edge cases
# ============================================================================


class TestLongAlphaStatsEdgeCases:
    """Boundary cases for the volatility-adjusted ``long_alpha_stats``."""

    @staticmethod
    def _pure_short_backtest() -> WeightBacktest:
        """Build a single-symbol, always-short backtest on a steadily rising price."""
        frame = _make_dfw(10, ["A"], lambda d, s: -0.5, lambda d, s: 100.0 + d * 0.5)
        return WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252)

    def test_pure_short_long_vol_zero(self) -> None:
        """Pure short book: long returns are all zero (long_vol = 0), so alpha stats are zero."""
        alpha = self._pure_short_backtest().long_alpha_stats
        for key in ("绝对收益", "夏普比率", "最大回撤"):
            assert alpha[key] == 0.0

    def test_constant_prices_bench_vol_zero(self) -> None:
        """Flat prices: every return is zero (bench_vol = 0), so alpha stats are zero."""
        frame = _make_dfw(10, ["A"], lambda d, s: 0.5, lambda d, s: 100.0)
        alpha = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252).long_alpha_stats
        assert alpha["绝对收益"] == 0.0

    def test_zero_weights_zero_alpha(self) -> None:
        """Zero weights: the benchmark may move, but long_vol = 0 forces zero alpha."""
        frame = _make_dfw(10, ["A"], lambda d, s: 0.0, lambda d, s: 100.0 + d * 0.5)
        alpha = WeightBacktest(frame, digits=2, fee_rate=0.0, n_jobs=1, yearly_days=252).long_alpha_stats
        assert alpha["绝对收益"] == 0.0

    def test_mixed_weights_valid_alpha(self) -> None:
        """Mixed long/short weights on varying prices take the non-degenerate alpha path."""
        pattern = [0.3, 0.3, -0.2, 0.5, -0.1]

        def price(d, s):
            drift = 0.5 if s == "A" else -0.3
            return 100.0 + d * drift + math.sin(d) * 2

        frame = _make_dfw(15, ["A", "B"], lambda d, s: pattern[d % 5], price)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, yearly_days=252)
        alpha = backtest.long_alpha_stats
        # Only the value types are checked here, not the magnitudes.
        assert isinstance(alpha["绝对收益"], float)
        assert isinstance(alpha["夏普比率"], float)
        # With only 15 days of data (below the 126-day threshold), 年胜率 stays 0.
        assert alpha["年胜率"] == 0.0

    def test_alpha_keys_complete_in_zero_vol_case(self) -> None:
        """The zero-volatility early return still reports every expected key."""
        alpha = self._pure_short_backtest().long_alpha_stats
        required = (
            "绝对收益",
            "年化收益",
            "夏普比率",
            "卡玛比率",
            "新高占比",
            "日胜率",
            "周胜率",
            "月胜率",
            "季胜率",
            "年胜率",
            "最大回撤",
            "年化波动率",
            "下行波动率",
            "新高间隔",
        )
        for k in required:
            assert k in alpha, f"Missing key in zero-vol case: {k}"


# ============================================================================
# 5. Single symbol edge cases
# ============================================================================


class TestSingleSymbolMetrics:
    """Metric output when the universe holds exactly one symbol."""

    def test_single_symbol_stats_complete(self) -> None:
        """A single-symbol run still exposes every summary key and counts one symbol."""
        cycle = [0.3, 0.3, -0.2, 0.0, 0.5]

        def price(d, s):
            return 100.0 + d * 0.5 + math.sin(d) * 2

        frame = _make_dfw(20, ["ONLY"], lambda d, s: cycle[d % 5], price)
        stats = WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, yearly_days=252).stats
        required = (
            "绝对收益",
            "年化收益",
            "夏普比率",
            "卡玛比率",
            "最大回撤",
            "日胜率",
            "周胜率",
            "月胜率",
            "季胜率",
            "年胜率",
            "交易次数",
            "年化交易次数",
            "多头占比",
            "空头占比",
            "品种数量",
        )
        for k in required:
            assert k in stats, f"Missing key: {k}"
        assert stats["品种数量"] == 1

    def test_single_symbol_long_short_sum(self) -> None:
        """For one symbol, the daily long and short returns add up to the daily total."""
        cycle = [0.3, -0.2, 0.5, 0.0, -0.4]
        frame = _make_dfw(15, ["X"], lambda d, s: cycle[d % 5], lambda d, s: 100.0 + d * 0.3)
        daily = WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, yearly_days=252).dailys
        recombined = daily["long_return"] + daily["short_return"]
        pd.testing.assert_series_equal(daily["return"], recombined, check_names=False, atol=1e-8)


# ============================================================================
# 6. Weight modes (cs / invalid weight_type) and many symbols
# ============================================================================


class TestCSMode:
    """Cross-sectional ("cs") weighting: the portfolio return aggregates symbols by sum."""

    def test_cs_long_daily_return(self) -> None:
        """In cs mode the daily total is the sum (not the mean) of per-symbol returns."""

        def weight(d, s):
            return 0.3 if s == "A" else -0.2

        def price(d, s):
            return 100.0 + d * (0.5 if s == "A" else -0.3) + 50 * (s != "A")

        frame = _make_dfw(10, ["A", "B"], weight, price)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, weight_type="cs", yearly_days=252)
        table = backtest.daily_return
        symbol_cols = [col for col in table.columns if col not in {"date", "total"}]
        np.testing.assert_allclose(
            table["total"].to_numpy(),
            table[symbol_cols].sum(axis=1).to_numpy(),
            atol=1e-6,
        )


class TestInvalidWeightType:
    """An unrecognised weight_type is expected to fall back silently to 'ts'."""

    def test_invalid_weight_type_defaults_to_ts(self) -> None:
        """Invalid weight_type should produce the same result as an explicit 'ts'."""
        dfw = _make_dfw(10, ["A", "B"], lambda d, s: 0.3, lambda d, s: 100.0 + d)
        # n_jobs=1 matches every other backtest in this module and keeps the two
        # runs identical apart from weight_type.
        bt_ts = WeightBacktest(dfw.copy(), weight_type="ts", n_jobs=1)
        bt_inv = WeightBacktest(dfw.copy(), weight_type="INVALID", n_jobs=1)
        assert bt_ts.stats["绝对收益"] == bt_inv.stats["绝对收益"]
        # One summary scalar can coincide by accident; the per-day return table
        # (total plus per-symbol columns) must match as well.
        pd.testing.assert_frame_equal(bt_ts.daily_return, bt_inv.daily_return)


class TestManySymbols:
    """Behaviour with a wider universe (five or more symbols)."""

    def test_five_symbols(self) -> None:
        """Five symbols in ts mode: the daily total is the mean of per-symbol returns."""
        symbols = [f"SYM_{i}" for i in range(5)]

        def weight(d, s):
            # Even-numbered symbols are long, odd-numbered ones short.
            return 0.3 if int(s[-1]) % 2 == 0 else -0.2

        def price(d, s):
            return 100.0 + d * (int(s[-1]) + 1) * 0.1

        frame = _make_dfw(15, symbols, weight, price)
        backtest = WeightBacktest(frame, digits=2, fee_rate=0.0002, n_jobs=1, weight_type="ts", yearly_days=252)
        assert backtest.stats["品种数量"] == 5

        table = backtest.daily_return
        symbol_cols = [col for col in table.columns if col not in {"date", "total"}]
        assert len(symbol_cols) == 5
        np.testing.assert_allclose(
            table["total"].to_numpy(),
            table[symbol_cols].mean(axis=1).to_numpy(),
            atol=1e-6,
        )