use crate::core::errors::WbtError;
use crate::core::native_engine::{PairsSoA, dt_to_date_key_fast};
use polars::prelude::*;
use std::collections::BTreeMap;
#[derive(Debug, Clone, PartialEq)]
pub struct AggRow {
pub sym_id: u32,
pub open_dt: i64,
pub close_dt: i64,
pub dir: &'static str,
pub open_price: f64,
pub close_price: f64,
pub hold_bars: i64,
pub profit_bp: f64,
pub count: i64,
}
pub fn aggregate_pairs(pairs: &PairsSoA) -> Vec<AggRow> {
struct Acc {
sum_pw: f64,
sum_w: f64,
count: i64,
hold_bars: i64,
dir: &'static str,
open_price: f64,
close_price: f64,
}
let mut groups: BTreeMap<(u32, i64, i64), Acc> = BTreeMap::new();
for i in 0..pairs.sym_ids.len() {
let key = (pairs.sym_ids[i], pairs.open_dts[i], pairs.close_dts[i]);
let w = pairs.counts[i] as f64;
let entry = groups.entry(key).or_insert_with(|| Acc {
sum_pw: 0.0,
sum_w: 0.0,
count: 0,
hold_bars: pairs.hold_bars[i],
dir: pairs.dirs[i],
open_price: pairs.open_prices[i],
close_price: pairs.close_prices[i],
});
entry.sum_pw += pairs.profit_bps[i] * w;
entry.sum_w += w;
entry.count += 1;
if pairs.hold_bars[i] > entry.hold_bars {
entry.hold_bars = pairs.hold_bars[i];
}
}
groups
.into_iter()
.map(|((sym_id, open_dt, close_dt), a)| {
let profit_bp = if a.sum_w != 0.0 {
(a.sum_pw / a.sum_w * 100.0).round() / 100.0
} else {
0.0
};
AggRow {
sym_id,
open_dt,
close_dt,
dir: a.dir,
open_price: a.open_price,
close_price: a.close_price,
hold_bars: a.hold_bars,
profit_bp,
count: a.count,
}
})
.collect()
}
fn select_key_trades(
agg: &[AggRow],
top: usize,
time_unit: TimeUnit,
) -> Vec<(i32, &'static str, usize)> {
let mut by_year: BTreeMap<i32, Vec<usize>> = BTreeMap::new();
for (idx, row) in agg.iter().enumerate() {
let year = dt_to_date_key_fast(row.close_dt, time_unit) / 10000;
by_year.entry(year).or_default().push(idx);
}
let mut out: Vec<(i32, &'static str, usize)> = Vec::new();
for (year, mut idxs) in by_year {
idxs.sort_by(|&a, &b| {
agg[b]
.profit_bp
.total_cmp(&agg[a].profit_bp)
.then(agg[a].open_dt.cmp(&agg[b].open_dt))
});
let mut best: Vec<usize> = Vec::new();
for &i in idxs.iter().take(top) {
best.push(i);
out.push((year, "best", i));
}
idxs.sort_by(|&a, &b| {
agg[a]
.profit_bp
.total_cmp(&agg[b].profit_bp)
.then(agg[a].open_dt.cmp(&agg[b].open_dt))
});
for &i in idxs.iter().filter(|i| !best.contains(i)).take(top) {
out.push((year, "worst", i));
}
}
out
}
fn agg_rows_to_columns(
pairs: &PairsSoA,
rows: impl Iterator<Item = (Option<(i32, &'static str)>, AggRow)>,
) -> Result<DataFrame, WbtError> {
let mut years: Vec<i32> = Vec::new();
let mut kinds: Vec<&'static str> = Vec::new();
let mut symbols: Vec<String> = Vec::new();
let mut open_dts: Vec<i64> = Vec::new();
let mut close_dts: Vec<i64> = Vec::new();
let mut dirs: Vec<&'static str> = Vec::new();
let mut open_px: Vec<f64> = Vec::new();
let mut close_px: Vec<f64> = Vec::new();
let mut hold_bars: Vec<i64> = Vec::new();
let mut profits: Vec<f64> = Vec::new();
let mut counts: Vec<i64> = Vec::new();
let mut has_year = false;
for (meta, r) in rows {
if let Some((y, kind)) = meta {
has_year = true;
years.push(y);
kinds.push(kind);
}
symbols.push(pairs.symbol_dict[r.sym_id as usize].clone());
open_dts.push(r.open_dt);
close_dts.push(r.close_dt);
dirs.push(r.dir);
open_px.push(r.open_price);
close_px.push(r.close_price);
hold_bars.push(r.hold_bars);
profits.push(r.profit_bp);
counts.push(r.count);
}
let open_series = Series::new("开仓时间".into(), &open_dts)
.cast(&DataType::Datetime(pairs.time_unit, None))
.map_err(WbtError::Polars)?;
let close_series = Series::new("平仓时间".into(), &close_dts)
.cast(&DataType::Datetime(pairs.time_unit, None))
.map_err(WbtError::Polars)?;
let mut cols: Vec<Column> = Vec::new();
if has_year {
cols.push(Series::new("year".into(), &years).into_column());
cols.push(Series::new("kind".into(), &kinds).into_column());
}
cols.push(Series::new("symbol".into(), &symbols).into_column());
cols.push(Series::new("交易方向".into(), &dirs).into_column());
cols.push(open_series.into_column());
cols.push(close_series.into_column());
cols.push(Series::new("开仓价格".into(), &open_px).into_column());
cols.push(Series::new("平仓价格".into(), &close_px).into_column());
cols.push(Series::new("持仓K线数".into(), &hold_bars).into_column());
cols.push(Series::new("盈亏比例".into(), &profits).into_column());
cols.push(Series::new("count".into(), &counts).into_column());
DataFrame::new_infer_height(cols).map_err(WbtError::Polars)
}
pub fn agg_to_df(pairs: &PairsSoA, agg: &[AggRow]) -> Result<DataFrame, WbtError> {
agg_rows_to_columns(pairs, agg.iter().map(|r| (None, r.clone())))
}
pub fn key_trades_to_df(
pairs: &PairsSoA,
agg: &[AggRow],
top: usize,
) -> Result<DataFrame, WbtError> {
let selected = select_key_trades(agg, top, pairs.time_unit);
agg_rows_to_columns(
pairs,
selected
.into_iter()
.map(|(y, kind, i)| (Some((y, kind)), agg[i].clone())),
)
}
pub fn aggregated_pairs_df(pairs: &PairsSoA) -> Result<DataFrame, WbtError> {
agg_to_df(pairs, &aggregate_pairs(pairs))
}
pub fn key_trades_df(pairs: &PairsSoA, top: usize) -> Result<DataFrame, WbtError> {
key_trades_to_df(pairs, &aggregate_pairs(pairs), top)
}
#[cfg(test)]
mod tests {
use super::*;
fn make_pairs(rows: &[(u32, i64, i64, i64, f64, i64)], symbol_dict: Vec<String>) -> PairsSoA {
PairsSoA {
sym_ids: rows.iter().map(|r| r.0).collect(),
dirs: rows.iter().map(|_| "多头").collect(),
open_dts: rows.iter().map(|r| r.1).collect(),
close_dts: rows.iter().map(|r| r.2).collect(),
open_prices: rows.iter().map(|_| 100.0).collect(),
close_prices: rows.iter().map(|_| 110.0).collect(),
hold_bars: rows.iter().map(|r| r.3).collect(),
event_seqs: rows.iter().map(|_| "开多->平多").collect(),
profit_bps: rows.iter().map(|r| r.4).collect(),
counts: rows.iter().map(|r| r.5).collect(),
time_unit: TimeUnit::Milliseconds,
symbol_dict,
}
}
const DT_2024: i64 = 1_717_372_800_000; const DT_2024B: i64 = 1_717_459_200_000; const DT_2025: i64 = 1_748_908_800_000;
#[test]
fn aggregate_empty_is_empty() {
let pairs = make_pairs(&[], vec!["A".into()]);
assert!(aggregate_pairs(&pairs).is_empty());
let df = aggregated_pairs_df(&pairs).unwrap();
assert_eq!(df.height(), 0);
assert!(df.column("count").is_ok());
assert!(df.column("盈亏比例").is_ok());
}
#[test]
fn aggregate_merges_same_open_close() {
let pairs = make_pairs(
&[
(0, DT_2024, DT_2024B, 5, 100.0, 1),
(0, DT_2024, DT_2024B, 5, 200.0, 3),
],
vec!["A".into()],
);
let agg = aggregate_pairs(&pairs);
assert_eq!(agg.len(), 1, "同 key 应合并为一条");
let r = &agg[0];
assert_eq!(r.count, 2, "count = 合并的记录数");
assert_eq!(r.hold_bars, 5, "hold_bars 取共享值,不求和");
assert!((r.profit_bp - 175.0).abs() < 1e-9, "got {}", r.profit_bp);
}
#[test]
fn aggregate_keeps_distinct_close() {
let pairs = make_pairs(
&[
(0, DT_2024, DT_2024B, 5, 100.0, 1),
(0, DT_2024, DT_2025, 5, 200.0, 1),
],
vec!["A".into()],
);
assert_eq!(aggregate_pairs(&pairs).len(), 2, "平仓时间不同应保留两条");
}
#[test]
fn key_trades_best_worst_per_year() {
let pairs = make_pairs(
&[
(0, DT_2024, DT_2024B, 1, 300.0, 1),
(1, DT_2024, DT_2024B, 1, -50.0, 1),
(2, DT_2024, DT_2024B, 1, 10.0, 1),
],
vec!["A".into(), "B".into(), "C".into()],
);
let agg = aggregate_pairs(&pairs);
let sel = select_key_trades(&agg, 2, TimeUnit::Milliseconds);
let best: Vec<f64> = sel
.iter()
.filter(|(_, k, _)| *k == "best")
.map(|(_, _, i)| agg[*i].profit_bp)
.collect();
let worst: Vec<f64> = sel
.iter()
.filter(|(_, k, _)| *k == "worst")
.map(|(_, _, i)| agg[*i].profit_bp)
.collect();
assert_eq!(best, vec![300.0, 10.0], "best 降序取前 2");
assert_eq!(worst, vec![-50.0], "worst 剔除已入 best 的交易");
}
#[test]
fn key_trades_top_exceeds_count() {
let pairs = make_pairs(
&[
(0, DT_2024, DT_2024B, 1, 100.0, 1),
(1, DT_2024, DT_2024B, 1, -20.0, 1),
],
vec!["A".into(), "B".into()],
);
let agg = aggregate_pairs(&pairs);
let sel = select_key_trades(&agg, 3, TimeUnit::Milliseconds);
assert_eq!(sel.iter().filter(|(_, k, _)| *k == "best").count(), 2);
assert_eq!(sel.iter().filter(|(_, k, _)| *k == "worst").count(), 0);
}
#[test]
fn key_trades_groups_by_close_year() {
let pairs = make_pairs(
&[
(0, DT_2024, DT_2024B, 1, 100.0, 1),
(1, DT_2024, DT_2025, 1, 200.0, 1),
],
vec!["A".into(), "B".into()],
);
let agg = aggregate_pairs(&pairs);
let sel = select_key_trades(&agg, 5, TimeUnit::Milliseconds);
let years: std::collections::BTreeSet<i32> = sel.iter().map(|(y, _, _)| *y).collect();
assert_eq!(years, [2024, 2025].into_iter().collect(), "按平仓年份分组");
}
#[test]
fn key_trades_df_has_year_kind_columns() {
let pairs = make_pairs(&[(0, DT_2024, DT_2024B, 1, 100.0, 1)], vec!["A".into()]);
let df = key_trades_df(&pairs, 3).unwrap();
assert!(df.column("year").is_ok());
assert!(df.column("kind").is_ok());
assert!(df.column("symbol").is_ok());
assert_eq!(df.height(), 1, "1 笔 → 仅 best 1(worst 去重后为空)");
}
}