from __future__ import annotations
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from ._common import (
COLOR_DRAWDOWN,
COLOR_RETURN,
COLOR_TOTAL,
apply_default_layout,
figure_to_html,
)
def plot_backtest_overview(
    daily_return: pd.DataFrame,
    col: str = "total",
    title: str | None = "回测概览",
    to_html: bool = False,
) -> go.Figure | str:
    """Build a three-panel overview figure for a series of daily returns.

    Panels (2x2 grid whose bottom row spans both columns):

    * top-left: drawdown as a filled area (primary y-axis) overlaid with the
      cumulative return curve (secondary y-axis);
    * top-right: histogram of the daily returns multiplied by 100 (percent);
    * bottom: heatmap of per-month summed returns, one row per year, with
      cell labels formatted as percentages.

    The cumulative return is the running *sum* of ``daily_return[col]`` (no
    compounding), and the drawdown is that sum minus its running maximum, so
    it is never positive.

    Args:
        daily_return: Frame with a ``"date"`` column (anything accepted by
            ``pd.to_datetime``) and the return column ``col``. Returns are
            presumably fractions (0.01 == 1%), since the histogram and the
            heatmap labels multiply them by 100 — confirm with callers.
        col: Name of the return column to plot.
        title: Figure title, or ``None`` for none.
        to_html: When true, return ``figure_to_html(fig)`` instead of the
            figure object.

    Returns:
        The plotly figure, or its HTML string when ``to_html`` is true. If
        ``daily_return`` is empty or has no ``col`` column, the subplot grid
        is returned without any traces.
    """
    fig = make_subplots(
        rows=2,
        cols=2,
        specs=[
            [{"secondary_y": True}, {"secondary_y": False}],
            [{"colspan": 2}, None],
        ],
        subplot_titles=("回撤 & 累计收益", "日收益分布", "月度收益热力图"),
        vertical_spacing=0.12,
        horizontal_spacing=0.08,
    )
    if daily_return.empty or col not in daily_return.columns:
        apply_default_layout(fig, title=title, height=700)
        return figure_to_html(fig) if to_html else fig

    df = daily_return.copy()
    df["date"] = pd.to_datetime(df["date"])
    df = df.sort_values("date").reset_index(drop=True)

    # Top-left: drawdown (primary axis) + cumulative return (secondary axis).
    cumulative = df[col].cumsum()
    drawdown = cumulative - cumulative.cummax()
    fig.add_trace(
        go.Scatter(
            x=df["date"],
            y=drawdown,
            fill="tozeroy",
            fillcolor=COLOR_DRAWDOWN,
            line={"color": "rgba(255,59,59,0.6)", "width": 1},
            name="回撤",
        ),
        row=1,
        col=1,
        secondary_y=False,
    )
    fig.add_trace(
        go.Scatter(
            x=df["date"],
            y=cumulative,
            mode="lines",
            line={"color": COLOR_TOTAL, "width": 1.5},
            name="累计收益",
        ),
        row=1,
        col=1,
        secondary_y=True,
    )

    # Top-right: distribution of daily returns, in percent.
    fig.add_trace(
        go.Histogram(
            x=df[col].dropna() * 100,
            nbinsx=40,
            marker_color=COLOR_RETURN,
            opacity=0.7,
            name="日收益",
            showlegend=False,
        ),
        row=1,
        col=2,
    )

    # Bottom: monthly return heatmap. Rows whose date is NaT are dropped
    # first so the year/month keys stay integers (labels "2023", not
    # "2023.0"); groupby would discard those rows anyway.
    dated = df[df["date"].notna()]
    pivot = (
        dated.assign(year=dated["date"].dt.year, month=dated["date"].dt.month)
        .groupby(["year", "month"])[col]
        .sum()
        .unstack()
    )
    # Months without any data stay NaN -> None (rendered as a blank cell)
    # instead of being filled with 0, which would be indistinguishable from
    # a genuinely flat month.
    z = [
        [None if pd.isna(v) else float(v) for v in row]
        for row in pivot.to_numpy().tolist()
    ]
    text = [["" if v is None else f"{v * 100:.2f}%" for v in row] for row in z]
    fig.add_trace(
        go.Heatmap(
            x=[str(m) for m in pivot.columns],
            y=[str(y) for y in pivot.index],
            z=z,
            text=text,
            texttemplate="%{text}",
            colorscale="RdYlGn",
            zmid=0,
            showscale=False,
        ),
        row=2,
        col=1,
    )

    fig.update_layout(
        template="plotly_white",
        height=700,
        title=title,
        title_font_size=14,
        margin={"l": 60, "r": 40, "t": 80, "b": 60},
    )
    return figure_to_html(fig) if to_html else fig
def plot_colored_table(
    stats: dict,
    title: str | None = "绩效指标",
    to_html: bool = False,
) -> go.Figure | str:
    """Render ``stats`` as a two-column (metric, value) table with tinted values.

    Floats (Python or any NumPy float width) are shown with 4 decimals;
    everything else goes through ``str``. Numeric values (Python/NumPy ints
    and floats, but not bools) get a background tint by sign: red for
    positive, green for negative, white for zero/NaN/non-numeric. This is
    presumably the red-up / green-down convention.

    NOTE(review): ``plot_backtest_overview`` colors its monthly heatmap with
    ``RdYlGn`` (negative red, positive green), which is the opposite
    mapping. Confirm which one is intended.

    Args:
        stats: Mapping of metric name to value, displayed in iteration order.
        title: Figure title, or ``None`` for none.
        to_html: When true, return ``figure_to_html(fig)`` instead of the
            figure object.

    Returns:
        The plotly figure, or its HTML string when ``to_html`` is true. An
        empty ``stats`` yields a figure with no traces.
    """
    fig = go.Figure()
    if not stats:
        apply_default_layout(fig, title=title)
        return figure_to_html(fig) if to_html else fig

    def _fmt(v: object) -> str:
        # is_float covers np.float32/np.float16 too; those are not subclasses
        # of the builtin float, so an isinstance check would miss them.
        if pd.api.types.is_float(v):
            return f"{float(v):.4f}"
        return str(v)

    def _cell_color(v: object) -> str:
        # is_integer accepts NumPy ints but rejects bool, so a True flag is
        # not tinted as a "positive" number. NaN fails both comparisons.
        if pd.api.types.is_float(v) or pd.api.types.is_integer(v):
            if v > 0:
                return "rgba(231,76,60,0.12)"
            if v < 0:
                return "rgba(46,204,113,0.12)"
        return "white"

    keys = list(stats)
    values = list(stats.values())
    fig.add_trace(
        go.Table(
            header={
                "values": ["指标", "数值"],
                "fill_color": "#3498db",
                "font_color": "white",
                "align": "center",
                "font_size": 13,
            },
            cells={
                "values": [keys, [_fmt(v) for v in values]],
                "fill_color": [["white"] * len(keys), [_cell_color(v) for v in values]],
                "align": ["left", "right"],
                "font_size": 12,
            },
        )
    )
    # Grow the figure with the row count, with a 400px floor.
    apply_default_layout(fig, title=title, height=max(400, 30 * len(keys) + 100))
    return figure_to_html(fig) if to_html else fig
def plot_long_short_comparison(
    daily_return: pd.DataFrame,
    stats: dict,
    long_stats: dict,
    short_stats: dict,
    title: str | None = "多空对比",
    to_html: bool = False,
) -> go.Figure | str:
    """Plot the long-short cumulative return beside a metric comparison table.

    Left (60% width): running sum of ``daily_return["total"]`` on a
    ``.1%``-formatted y-axis.

    Right: a table comparing the metrics ``"年化"`` (annualized return),
    ``"夏普"`` (Sharpe), ``"最大回撤"`` (max drawdown), ``"卡玛"`` (Calmar),
    ``"日胜率"`` (daily win rate) and ``"绝对收益"`` (absolute return) across
    the long-short, long and short stats dicts.

    Args:
        daily_return: Frame with a ``"date"`` column. Only its ``"total"``
            column is plotted, and only when present. Returns are presumably
            fractions, given the percent axis format — confirm with callers.
        stats: Metrics for the combined long-short book.
        long_stats: Metrics for the long leg.
        short_stats: Metrics for the short leg.
        title: Figure title, or ``None`` for none.
        to_html: When true, return ``figure_to_html(fig)`` instead of the
            figure object.

    Returns:
        The plotly figure, or its HTML string when ``to_html`` is true. A
        metric row appears when any of the three dicts has that key; dicts
        lacking it show ``"N/A"``. Float values (Python or NumPy) are shown
        with 4 decimals.
    """
    fig = make_subplots(
        rows=1,
        cols=2,
        column_widths=[0.6, 0.4],
        specs=[[{"type": "xy"}, {"type": "table"}]],
        subplot_titles=("累计收益曲线", "关键指标对比"),
    )

    if not daily_return.empty:
        df = daily_return.copy()
        df["date"] = pd.to_datetime(df["date"])
        df = df.sort_values("date").reset_index(drop=True)
        # (column, line color, legend name); add entries to plot more legs.
        for col, color, name in [("total", COLOR_TOTAL, "多空")]:
            if col not in df.columns:
                continue
            fig.add_trace(
                go.Scatter(
                    x=df["date"],
                    y=df[col].cumsum(),
                    mode="lines",
                    line={"color": color, "width": 1.5},
                    name=name,
                ),
                row=1,
                col=1,
            )

    metric_keys = ["年化", "夏普", "最大回撤", "卡玛", "日胜率", "绝对收益"]

    def _get(d: dict, k: str) -> str:
        v = d.get(k, "N/A")
        # is_float also matches np.float32/np.float16, which are not
        # subclasses of the builtin float.
        if pd.api.types.is_float(v):
            return f"{float(v):.4f}"
        return str(v)

    metric_names = [
        k for k in metric_keys if k in stats or k in long_stats or k in short_stats
    ]
    # One formatted column per stats dict, in header order: 多空, 多头, 空头.
    value_columns = [
        [_get(d, k) for k in metric_names] for d in (stats, long_stats, short_stats)
    ]
    fig.add_trace(
        go.Table(
            header={
                "values": ["指标", "多空", "多头", "空头"],
                "fill_color": "#3498db",
                "font_color": "white",
                "align": "center",
                "font_size": 12,
            },
            cells={
                "values": [metric_names, *value_columns],
                "align": ["left", "right", "right", "right"],
                "font_size": 11,
            },
        ),
        row=1,
        col=2,
    )

    fig.update_layout(
        template="plotly_white",
        height=500,
        title=title,
        title_font_size=14,
        margin={"l": 60, "r": 40, "t": 80, "b": 60},
    )
    fig.update_yaxes(tickformat=".1%", row=1, col=1)
    return figure_to_html(fig) if to_html else fig