from __future__ import annotations
import os
from pathlib import Path
import pandas as pd
import pytest
from wbt import WeightBacktest, generate_backtest_report
from wbt.report import HtmlReportBuilder
from wbt.report._generator import (
_normalize_stats_for_czsc_view,
_validate_input_data,
)
from wbt.report._plot_backtest import get_performance_metrics_cards
def test_validate_rejects_missing_columns(sample_dfw: pd.DataFrame) -> None:
    """Validation must raise ``ValueError`` when the ``price`` column is absent."""
    without_price = sample_dfw.drop(columns=["price"])
    # The matched fragment is the "missing required columns" part of the message.
    with pytest.raises(ValueError, match="缺少必需列"):
        _validate_input_data(without_price)
def test_validate_rejects_empty_dataframe() -> None:
    """Validation must raise ``ValueError`` for a frame with columns but no rows."""
    column_names = ("dt", "symbol", "weight", "price")
    # All four columns are present, so only the zero-row condition is exercised.
    no_rows = pd.DataFrame({column: [] for column in column_names})
    with pytest.raises(ValueError, match="不能为空"):
        _validate_input_data(no_rows)
def test_validate_rejects_nan_in_weight_or_price(sample_dfw: pd.DataFrame) -> None:
    """Validation must raise ``ValueError`` when a ``weight`` cell holds NaN.

    NOTE(review): the test name also mentions ``price``, but only the
    ``weight`` column is poisoned here -- confirm whether a price case is
    wanted as well.
    """
    with_nan = sample_dfw.copy()  # work on a copy so the fixture stays untouched
    with_nan.loc[0, "weight"] = float("nan")
    with pytest.raises(ValueError, match="不能包含空值"):
        _validate_input_data(with_nan)
def test_normalize_adds_czsc_aliases() -> None:
    """Normalized stats expose the short alias keys next to the original keys."""
    source = {"年化收益": 0.15, "夏普比率": 1.2, "卡玛比率": 0.8, "其他": "x"}
    normalized = _normalize_stats_for_czsc_view(source)

    # Three short aliases, one unrelated passthrough key, and one original
    # long key that must still be present after normalization.
    expected = {
        "年化": 0.15,
        "夏普": 1.2,
        "卡玛": 0.8,
        "其他": "x",
        "年化收益": 0.15,
    }
    for key, value in expected.items():
        assert normalized[key] == value
def test_normalize_preserves_existing_short_keys() -> None:
    """A short key already in the input keeps its own value after normalization."""
    # Both the short and the long key are supplied with different values;
    # the short key's value (0.05) is the one that must survive.
    normalized = _normalize_stats_for_czsc_view({"年化": 0.05, "年化收益": 0.10})
    assert normalized["年化"] == 0.05
def test_normalize_does_not_mutate_input() -> None:
    """The alias must appear in the returned mapping, never in the caller's dict."""
    stats = {"年化收益": 0.15}
    snapshot = dict(stats)  # independent copy taken before the call
    out = _normalize_stats_for_czsc_view(stats)

    # These were fused onto one physical line, which is not valid Python;
    # each check is its own statement.
    assert "年化" not in stats
    assert "年化" in out
    # Guard against any other in-place change, not just the alias key.
    assert stats == snapshot
def test_metrics_cards_on_real_wbt_stats(wb: WeightBacktest) -> None:
    """Cards built from the ``wb`` fixture's normalized stats have the expected shape.

    Checks the container type, the card count, the exact set of labels, and
    the types of each card's ``value`` and ``is_positive`` entries.
    """
    expected_labels = {
        "年化收益率",
        "单笔收益(BP)",
        "交易胜率",
        "持仓K线数",
        "最大回撤",
        "年化",
        "夏普",
        "卡玛",
        "年化波动率",
        "多头占比",
        "空头占比",
    }

    stats_view = _normalize_stats_for_czsc_view(wb.stats)
    cards = get_performance_metrics_cards(stats_view)

    assert isinstance(cards, list)
    assert len(cards) == 11
    assert {card["label"] for card in cards} == expected_labels

    for card in cards:
        assert isinstance(card["value"], str)
        assert isinstance(card["is_positive"], bool)
def test_html_builder_render_contains_title_and_basic_structure() -> None:
    """Rendered HTML contains the title, header info, metric card and footer text."""
    # Each step rebinds to the call's return value, mirroring a fluent chain.
    builder = HtmlReportBuilder(title="My Test Report")
    builder = builder.add_header({"Period": "2020-2024"}, subtitle="Hello")
    builder = builder.add_metrics([{"label": "Return", "value": "15%", "is_positive": True}])
    builder = builder.add_footer("End of report")
    html = builder.render()

    expected_fragments = (
        "<!DOCTYPE html>",
        "<title>My Test Report</title>",
        "My Test Report",
        "Period: 2020-2024",
        "Hello",
        "Return",
        "15%",
        "End of report",
    )
    for fragment in expected_fragments:
        assert fragment in html
def test_html_builder_save_writes_file(tmp_path: Path) -> None:
    """``save`` returns the path it was given and writes an HTML document there."""
    # The "subdir" directory is never created by this test before saving.
    target = tmp_path / "subdir" / "report.html"

    report = HtmlReportBuilder(title="x")
    report.add_header({"a": "b"})
    report.add_footer()
    saved_to = report.save(str(target))

    assert saved_to == str(target)
    assert target.exists()
    content = target.read_text(encoding="utf-8")
    assert content.startswith("<!DOCTYPE html>")
def test_generate_backtest_report_writes_valid_html(
    sample_dfw: pd.DataFrame,
    tmp_path: Path,
) -> None:
    """The generated report lands at the requested path and contains key markers."""
    report_file = tmp_path / "backtest_report.html"
    result = generate_backtest_report(
        sample_dfw,
        output_path=str(report_file),
        title="测试报告",
        n_jobs=1,
    )

    assert result == str(report_file)
    assert report_file.exists()

    html = report_file.read_text(encoding="utf-8")
    # Substrings that must each occur somewhere in the written document.
    required_markers = (
        "<!DOCTYPE html>",
        "<title>测试报告</title>",
        "header-section",
        "section-title",
        'class="nav nav-tabs"',
        'class="chart-card"',
        "回测统计",
        "多空对比",
        "plotly-graph-div",
        "wbt 权重回测引擎",
    )
    for marker in required_markers:
        assert marker in html
def test_generate_backtest_report_default_path_in_cwd(
    sample_dfw: pd.DataFrame,
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Without ``output_path`` the report is written as ``backtest_report.html`` in the cwd."""
    # Switch the working directory to the temp dir for the duration of the test.
    monkeypatch.chdir(tmp_path)
    default_target = tmp_path / "backtest_report.html"

    result = generate_backtest_report(sample_dfw, n_jobs=1)

    # The returned path is made absolute before it is compared.
    assert os.path.abspath(result) == str(default_target)
    assert default_target.exists()
def test_generate_backtest_report_rejects_bad_input(tmp_path: Path) -> None:
    """``generate_backtest_report`` raises ``ValueError`` for an unusable frame.

    The frame passed in has no rows and no ``price`` column.
    """
    # The assignment and the ``with`` header were fused onto one physical
    # line, which is not valid Python; they are separate statements here.
    bad = pd.DataFrame({"dt": [], "symbol": [], "weight": []})
    with pytest.raises(ValueError):
        generate_backtest_report(bad, output_path=str(tmp_path / "x.html"))
def test_generate_backtest_report_contains_all_11_metric_labels(
    sample_dfw: pd.DataFrame,
    tmp_path: Path,
) -> None:
    """Each listed metric label occurs in the generated report's HTML.

    NOTE(review): the function name says 11, but only the eight labels below
    are searched for -- confirm whether the remaining three are left out on
    purpose.
    """
    report_file = tmp_path / "r.html"
    generate_backtest_report(sample_dfw, output_path=str(report_file), n_jobs=1)
    html = report_file.read_text(encoding="utf-8")

    checked_labels = (
        "年化收益率",
        "单笔收益(BP)",
        "交易胜率",
        "持仓K线数",
        "最大回撤",
        "年化波动率",
        "多头占比",
        "空头占比",
    )
    for label in checked_labels:
        assert label in html, f"missing metric label: {label}"