use crate::core::errors::WbtError;
use crate::core::native_engine::PairsSoA;
use crate::core::trade_dir::TradeDir;
use crate::core::utils::RoundToNthDigit;
use serde::Serialize;
/// Summary statistics for a set of closed trade pairs, as filled in by
/// `evaluate_pairs_soa`.
///
/// Trade counts are count-weighted: an input row with `count = 3` contributes three
/// trades. A trade with profit `>= 0` counts as a win and a negative profit as a loss.
/// NOTE(review): profit fields are presumably in basis points (the input column is
/// named `profit_bps`) — confirm with the producer of `PairsSoA`.
#[derive(Clone, Serialize)]
pub struct EvaluatePairs {
    /// Trade direction (long, short, or both) this result is labelled with.
    pub trade_direction: TradeDir,
    /// Total number of trades evaluated.
    pub trade_count: usize,
    /// Sum of all trade profits (`sum_win + sum_loss`), rounded to 2 digits.
    pub total_profit: f64,
    /// Average profit per trade (`total profit / trade_count`), rounded to 2 digits.
    pub single_trade_profit: f64,
    /// Number of trades with profit `>= 0`.
    pub win_trade_count: usize,
    /// Count-weighted sum of winning-trade profits, rounded to 2 digits.
    pub sum_win: f64,
    /// Average profit of a winning trade (0 when there are none), rounded to 4 digits.
    pub win_one: f64,
    /// Number of trades with negative profit.
    pub loss_trade_count: usize,
    /// Count-weighted sum of losing-trade profits (non-positive), rounded to 2 digits.
    pub sum_loss: f64,
    /// Average profit of a losing trade (non-positive; 0 when there are none),
    /// rounded to 4 digits.
    pub loss_one: f64,
    /// `win_trade_count / trade_count`, rounded to 4 digits.
    pub win_rate: f64,
    /// `sum_win / |sum_loss|` (0 when `sum_loss` is 0), rounded to 4 digits.
    pub total_profit_loss_ratio: f64,
    /// `win_one / |loss_one|` (0 when `loss_one` is 0), rounded to 4 digits.
    pub single_profit_loss_ratio: f64,
    /// Fraction of trades, taken from worst to best, needed before the cumulative
    /// profit reaches `>= 0`; 1.0 when the total profit is `<= 0`. Rounded to 4 digits.
    pub break_even_point: f64,
    /// Count-weighted average holding period in bars, rounded to 2 digits.
    pub position_k_days: f64,
}
impl Default for EvaluatePairs {
    /// Returns an all-zero evaluation labelled [`TradeDir::LongShort`].
    fn default() -> Self {
        Self {
            trade_direction: TradeDir::LongShort,
            // Trade counters.
            trade_count: 0,
            win_trade_count: 0,
            loss_trade_count: 0,
            // Profit aggregates.
            total_profit: 0.0,
            single_trade_profit: 0.0,
            sum_win: 0.0,
            win_one: 0.0,
            sum_loss: 0.0,
            loss_one: 0.0,
            // Derived ratios and averages.
            win_rate: 0.0,
            total_profit_loss_ratio: 0.0,
            single_profit_loss_ratio: 0.0,
            break_even_point: 0.0,
            position_k_days: 0.0,
        }
    }
}
/// Computes the break-even point of grouped trade results.
///
/// `profit_count_pairs` holds `(profit, count)` groups, each standing for `count`
/// trades that returned `profit`. Trades are taken one at a time from the worst
/// profit to the best. The result is `trades_taken / trade_count` at the first trade
/// where the running profit sum becomes `>= 0`.
///
/// Returns `1.0` when the summed profit of all groups is `<= 0`, or when the running
/// sum never reaches zero. Groups whose count is `<= 0` are ignored.
///
/// The slice is sorted in place by ascending profit.
fn compute_break_even_point(profit_count_pairs: &mut [(f64, f64)], trade_count: f64) -> f64 {
    // `total_cmp` is a total order even when a profit is NaN. The previous
    // `partial_cmp(..).unwrap_or(Equal)` is not, and `sort_by` may panic on a
    // comparator that is not a total order (Rust >= 1.81).
    profit_count_pairs.sort_by(|a, b| a.0.total_cmp(&b.0));
    let mut sum = 0.0;
    let mut seen = 0.0;
    let mut break_even_point = 1.0;
    let mut found = false;
    for &(p, c) in profit_count_pairs.iter() {
        if c <= 0.0 {
            continue;
        }
        if found {
            sum += p * c;
            seen += c;
            continue;
        }
        // How many trades of this group to take before re-checking the running sum.
        let k = if p > 0.0 {
            // Smallest whole number of winning trades that lifts the sum to >= 0:
            // at least one, at most the whole group.
            (-sum / p).ceil().max(1.0).min(c)
        } else if sum + p >= 0.0 {
            // A zero-profit trade on a zero running sum breaks even immediately.
            // Stop after the first trade of the group, not after all of them.
            c.min(1.0)
        } else {
            // Non-positive trades on a negative sum cannot reach break-even,
            // so the whole group is taken.
            c
        };
        sum += p * k;
        seen += k;
        if sum >= 0.0 {
            break_even_point = seen / trade_count;
            found = true;
        }
        if k < c {
            sum += p * (c - k);
            seen += c - k;
        }
    }
    if sum <= 0.0 { 1.0 } else { break_even_point }
}
/// Evaluates closed trade pairs stored column-wise in a [`PairsSoA`].
///
/// Each row `i` is a group of `counts[i]` trades with profit `profit_bps[i]`, holding
/// period `hold_bars[i]` and direction label `dirs[i]`. Rows whose count is `<= 0` are
/// skipped. When `trade_dir` is `Long` (resp. `Short`), only rows labelled `"多头"`
/// (resp. `"空头"`) are evaluated; `LongShort` keeps every row.
///
/// Profits `>= 0` count as wins and negative profits as losses. All statistics are
/// count-weighted and rounded (2 or 4 decimal digits) before being returned.
///
/// When there is nothing to evaluate (no rows, or no rows left after filtering), an
/// all-zero result is returned. Its `trade_direction` is still the requested
/// `trade_dir`, so callers can tell which side came back empty.
///
/// # Errors
///
/// Never returns `Err` at present; the `Result` is part of the public signature.
///
/// # Panics
///
/// May panic (index out of bounds) if `dirs`, `counts` or `hold_bars` is shorter than
/// `profit_bps` for a row that gets read.
pub fn evaluate_pairs_soa(
    pairs: &PairsSoA,
    trade_dir: TradeDir,
) -> Result<EvaluatePairs, WbtError> {
    let n = pairs.profit_bps.len();
    if n == 0 {
        // Keep the requested direction instead of the default's `LongShort` label.
        return Ok(EvaluatePairs {
            trade_direction: trade_dir,
            ..EvaluatePairs::default()
        });
    }
    // Direction label a row must carry to be counted; `None` keeps both sides.
    let dir_filter = match trade_dir {
        TradeDir::Long => Some("多头"),
        TradeDir::Short => Some("空头"),
        TradeDir::LongShort => None,
    };
    let mut trade_count = 0.0f64;
    let mut win_trade_count = 0.0f64;
    let mut sum_win = 0.0f64;
    let mut loss_trade_count = 0.0f64;
    let mut sum_loss = 0.0f64;
    let mut sum_hold_bars = 0.0f64;
    // (profit, count) groups fed to the break-even computation.
    let mut profit_count_pairs: Vec<(f64, f64)> = Vec::with_capacity(n);
    for i in 0..n {
        if let Some(filter_str) = dir_filter
            && pairs.dirs[i] != filter_str
        {
            continue;
        }
        let p = pairs.profit_bps[i];
        let c = pairs.counts[i] as f64;
        if c <= 0.0 {
            continue;
        }
        trade_count += c;
        if p >= 0.0 {
            win_trade_count += c;
            sum_win += p * c;
        } else {
            loss_trade_count += c;
            sum_loss += p * c;
        }
        sum_hold_bars += (pairs.hold_bars[i] as f64) * c;
        profit_count_pairs.push((p, c));
    }
    if trade_count <= 0.0 {
        // Every row was filtered out: same as the empty case, keep the direction.
        return Ok(EvaluatePairs {
            trade_direction: trade_dir,
            ..EvaluatePairs::default()
        });
    }
    let position_k_days = sum_hold_bars / trade_count;
    let win_one = if win_trade_count > 0.0 {
        sum_win / win_trade_count
    } else {
        0.0
    };
    let loss_one = if loss_trade_count > 0.0 {
        sum_loss / loss_trade_count
    } else {
        0.0
    };
    let win_rate = win_trade_count / trade_count;
    let break_even_point = compute_break_even_point(&mut profit_count_pairs, trade_count);
    // Ratios fall back to 0 when there is no loss to divide by.
    let total_profit_loss_ratio = if sum_loss == 0.0 {
        0.0
    } else {
        sum_win / sum_loss.abs()
    };
    let single_profit_loss_ratio = if loss_one == 0.0 {
        0.0
    } else {
        win_one / loss_one.abs()
    };
    Ok(EvaluatePairs {
        trade_direction: trade_dir,
        trade_count: trade_count as usize,
        total_profit: (sum_win + sum_loss).round_to_2_digit(),
        single_trade_profit: ((sum_win + sum_loss) / trade_count).round_to_2_digit(),
        win_trade_count: win_trade_count as usize,
        sum_win: sum_win.round_to_2_digit(),
        win_one: win_one.round_to_4_digit(),
        loss_trade_count: loss_trade_count as usize,
        sum_loss: sum_loss.round_to_2_digit(),
        loss_one: loss_one.round_to_4_digit(),
        win_rate: win_rate.round_to_4_digit(),
        total_profit_loss_ratio: total_profit_loss_ratio.round_to_4_digit(),
        single_profit_loss_ratio: single_profit_loss_ratio.round_to_4_digit(),
        break_even_point: break_even_point.round_to_4_digit(),
        position_k_days: position_k_days.round_to_2_digit(),
    })
}
// Unit tests for `evaluate_pairs_soa` and `compute_break_even_point`.
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::TimeUnit;
    /// Builds a single-symbol `PairsSoA` from the given per-row columns.
    /// Every other column gets a fixed placeholder value; `evaluate_pairs_soa` only
    /// reads `profit_bps`, `counts`, `hold_bars` and `dirs`.
    fn make_pairs(
        profit_bps: &[f64],
        counts: &[i64],
        hold_bars: &[i64],
        dirs: &[&'static str],
    ) -> PairsSoA {
        let n = profit_bps.len();
        PairsSoA {
            sym_ids: vec![0; n],
            dirs: dirs.to_vec(),
            open_dts: vec![0; n],
            close_dts: vec![1000; n],
            open_prices: vec![100.0; n],
            close_prices: vec![101.0; n],
            hold_bars: hold_bars.to_vec(),
            event_seqs: vec!["开多 -> 平多"; n],
            profit_bps: profit_bps.to_vec(),
            counts: counts.to_vec(),
            time_unit: TimeUnit::Milliseconds,
            symbol_dict: vec!["SYM0".into()],
        }
    }
    // No rows: the all-zero result is returned.
    #[test]
    fn evaluate_empty_pairs() {
        let pairs = make_pairs(&[], &[], &[], &[]);
        let ep = evaluate_pairs_soa(&pairs, TradeDir::LongShort).unwrap();
        assert_eq!(ep.trade_count, 0);
    }
    // Two winning trades (100, 50): no losses, so the profit/loss ratio falls back to 0.
    // position_k_days = (10 + 5) / 2 = 7.5.
    #[test]
    fn evaluate_all_win() {
        let pairs = make_pairs(&[100.0, 50.0], &[1, 1], &[10, 5], &["多头", "多头"]);
        let ep = evaluate_pairs_soa(&pairs, TradeDir::LongShort).unwrap();
        assert_eq!(ep.trade_count, 2);
        assert_eq!(ep.win_trade_count, 2);
        assert_eq!(ep.loss_trade_count, 0);
        assert_eq!(ep.win_rate, 1.0);
        assert_eq!(ep.total_profit, 150.0);
        assert_eq!(ep.single_trade_profit, 75.0);
        assert_eq!(ep.sum_win, 150.0);
        assert_eq!(ep.win_one, 75.0);
        assert_eq!(ep.position_k_days, 7.5);
        assert_eq!(ep.sum_loss, 0.0);
        assert_eq!(ep.total_profit_loss_ratio, 0.0);
    }
    // Two losing trades (-100, -50): win rate 0 and non-positive loss aggregates.
    #[test]
    fn evaluate_all_loss() {
        let pairs = make_pairs(&[-100.0, -50.0], &[1, 1], &[10, 5], &["空头", "空头"]);
        let ep = evaluate_pairs_soa(&pairs, TradeDir::LongShort).unwrap();
        assert_eq!(ep.trade_count, 2);
        assert_eq!(ep.loss_trade_count, 2);
        assert_eq!(ep.win_rate, 0.0);
        assert_eq!(ep.sum_loss, -150.0);
        assert_eq!(ep.loss_one, -75.0);
        assert_eq!(ep.total_profit, -150.0);
        assert_eq!(ep.single_trade_profit, -75.0);
    }
    // Count-weighted rows: 2 x 100, 1 x -50, 3 x 200 -> 6 trades.
    // sum_win = 800, sum_loss = -50, win_one = 800 / 5 = 160, ratios 16 and 3.2,
    // position_k_days = (2*10 + 1*5 + 3*20) / 6 = 85 / 6 ~= 14.17.
    #[test]
    fn evaluate_mixed() {
        let pairs = make_pairs(
            &[100.0, -50.0, 200.0],
            &[2, 1, 3],
            &[10, 5, 20],
            &["多头", "空头", "多头"],
        );
        let ep = evaluate_pairs_soa(&pairs, TradeDir::LongShort).unwrap();
        assert_eq!(ep.trade_count, 6);
        assert_eq!(ep.win_trade_count, 5);
        assert_eq!(ep.loss_trade_count, 1);
        assert_eq!(ep.win_rate, 0.8333);
        assert_eq!(ep.total_profit, 750.0);
        assert_eq!(ep.single_trade_profit, 125.0);
        assert_eq!(ep.sum_win, 800.0);
        assert_eq!(ep.win_one, 160.0);
        assert_eq!(ep.sum_loss, -50.0);
        assert_eq!(ep.loss_one, -50.0);
        assert_eq!(ep.total_profit_loss_ratio, 16.0);
        assert_eq!(ep.single_profit_loss_ratio, 3.2);
        assert_eq!(ep.position_k_days, 14.17);
    }
    // `TradeDir::Long` keeps only the "多头" row (the winning one).
    #[test]
    fn evaluate_direction_filter_long() {
        let pairs = make_pairs(&[100.0, -50.0], &[1, 1], &[10, 5], &["多头", "空头"]);
        let ep = evaluate_pairs_soa(&pairs, TradeDir::Long).unwrap();
        assert_eq!(ep.trade_count, 1);
        assert_eq!(ep.win_trade_count, 1);
    }
    // `TradeDir::Short` keeps only the "空头" row (the losing one).
    #[test]
    fn evaluate_direction_filter_short() {
        let pairs = make_pairs(&[100.0, -50.0], &[1, 1], &[10, 5], &["多头", "空头"]);
        let ep = evaluate_pairs_soa(&pairs, TradeDir::Short).unwrap();
        assert_eq!(ep.trade_count, 1);
        assert_eq!(ep.loss_trade_count, 1);
    }
    // All trades profitable: the first (worst) trade already breaks even (0.5 here).
    #[test]
    fn break_even_all_profit() {
        let mut pairs = vec![(100.0, 1.0), (50.0, 1.0)];
        let bep = compute_break_even_point(&mut pairs, 2.0);
        assert!(bep < 1.0);
    }
    // Total profit <= 0: the break-even point saturates at 1.0.
    #[test]
    fn break_even_all_loss() {
        let mut pairs = vec![(-100.0, 1.0), (-50.0, 1.0)];
        let bep = compute_break_even_point(&mut pairs, 2.0);
        assert_eq!(bep, 1.0);
    }
    // The zero-count group is ignored. -50 then 200 reach a non-negative sum only
    // after both trades, giving 2 / 2 = 1.0.
    #[test]
    fn break_even_zero_count_skipped() {
        let mut pairs = vec![(100.0, 0.0), (-50.0, 1.0), (200.0, 1.0)];
        let bep = compute_break_even_point(&mut pairs, 2.0);
        assert_eq!(bep, 1.0);
    }
}