use crate::core::errors::WbtError;
use crate::core::native_engine::PairsSoA;
use crate::core::trade_dir::TradeDir;
use crate::core::utils::RoundToNthDigit;
use serde::Serialize;
/// Summary statistics for a set of trades, as filled in by `evaluate_pairs_soa`.
///
/// Serialized through the serde `Serialize` derive; fields are emitted in declaration
/// order, so reordering them changes the serialized layout.
#[derive(Serialize)]
pub struct EvaluatePairs {
/// Direction filter (long, short, or both) the statistics were evaluated under.
pub trade_direction: TradeDir,
/// Number of trades counted (sum of the per-row counts that passed the filter).
pub trade_count: usize,
/// Sum of profit over all counted trades (units of `profit_bps`, presumably basis
/// points — confirm), rounded with `round_to_2_digit`.
pub total_profit: f64,
/// Average profit per counted trade, rounded with `round_to_2_digit`.
pub single_trade_profit: f64,
/// Number of trades whose profit is `>= 0` (zero-profit trades count as wins).
pub win_trade_count: usize,
/// Total profit of the winning trades, rounded with `round_to_2_digit`.
pub sum_win: f64,
/// Average profit per winning trade (0 when there are none), rounded with `round_to_4_digit`.
pub win_one: f64,
/// Number of trades whose profit is `< 0`.
pub loss_trade_count: usize,
/// Total (negative) profit of the losing trades, rounded with `round_to_2_digit`.
pub sum_loss: f64,
/// Average (negative) profit per losing trade (0 when there are none), rounded with
/// `round_to_4_digit`.
pub loss_one: f64,
/// `win_trade_count / trade_count`, rounded with `round_to_4_digit`.
pub win_rate: f64,
/// `sum_win / |sum_loss|`; reported as 0 when `sum_loss` is 0.
pub total_profit_loss_ratio: f64,
/// `win_one / |loss_one|`; reported as 0 when `loss_one` is 0.
pub single_profit_loss_ratio: f64,
/// Fraction of trades, taken from worst to best profit, after which the running profit
/// first becomes non-negative; 1.0 when the total profit is not positive.
pub break_even_point: f64,
/// Average `hold_bars` per counted trade (holding period measured in bars), rounded with
/// `round_to_2_digit`.
pub position_k_days: f64,
}
impl Default for EvaluatePairs {
    /// Builds an evaluation that covers no trades.
    ///
    /// Every count and every statistic is zero, and the direction is
    /// `TradeDir::LongShort`.
    fn default() -> Self {
        Self {
            trade_direction: TradeDir::LongShort,
            // Trade counts.
            trade_count: 0,
            win_trade_count: 0,
            loss_trade_count: 0,
            // Profit sums and per-trade averages.
            total_profit: 0.0,
            single_trade_profit: 0.0,
            sum_win: 0.0,
            win_one: 0.0,
            sum_loss: 0.0,
            loss_one: 0.0,
            // Derived ratios and holding statistics.
            win_rate: 0.0,
            total_profit_loss_ratio: 0.0,
            single_profit_loss_ratio: 0.0,
            break_even_point: 0.0,
            position_k_days: 0.0,
        }
    }
}
/// Returns the break-even point of a set of trades.
///
/// `profit_count_pairs` holds `(profit, count)` entries, each meaning `count` trades that
/// individually produced `profit`. The slice is sorted in place from worst to best profit
/// and accumulated in that order. The result is `trades_seen / trade_count` at the moment
/// the running profit first becomes non-negative. A positive-profit entry is split so that
/// only as many of its trades as needed are counted before the cut. Entries with a
/// non-positive count are ignored.
///
/// Returns `1.0` when the total profit over all entries is not positive (the trades never
/// break even). `trade_count` is the denominator and is expected to be positive.
fn compute_break_even_point(profit_count_pairs: &mut [(f64, f64)], trade_count: f64) -> f64 {
    // `total_cmp` is a genuine total order (NaN included). A `partial_cmp` that falls back
    // to `Equal` is not, and `sort_by` may panic when given an inconsistent comparator.
    profit_count_pairs.sort_by(|a, b| a.0.total_cmp(&b.0));
    let mut sum = 0.0;
    let mut seen = 0.0;
    let mut break_even_point = 1.0;
    let mut found = false;
    for &(p, c) in profit_count_pairs.iter() {
        if c <= 0.0 {
            continue;
        }
        if found || p <= 0.0 {
            // Either the point is already known, or a non-positive entry can only lower
            // (or keep) the running sum, so the whole entry is taken at once.
            sum += p * c;
            seen += c;
            if !found && sum >= 0.0 {
                break_even_point = seen / trade_count;
                found = true;
            }
            continue;
        }
        // Smallest whole number of trades from this entry that lifts the running sum to
        // zero, clamped to at least one and at most the entry's count.
        let mut k = (-sum / p).ceil().max(1.0).min(c);
        // Rounding in the division can leave `sum + p * k` a hair below zero even though
        // the exact quotient says `k` trades suffice; take one more trade in that case
        // instead of silently swallowing the rest of the entry.
        if sum + p * k < 0.0 && k < c {
            k = (k + 1.0).min(c);
        }
        sum += p * k;
        seen += k;
        if sum >= 0.0 {
            break_even_point = seen / trade_count;
            found = true;
        }
        // The remainder of this entry lies after the break-even cut (or, if the sum is
        // still negative, the entry is simply exhausted).
        if k < c {
            sum += p * (c - k);
            seen += c - k;
        }
    }
    if sum <= 0.0 { 1.0 } else { break_even_point }
}
/// Evaluates the trades stored column-wise in `pairs`, restricted to `trade_dir`.
///
/// Row `i` stands for `counts[i]` trades that each earned `profit_bps[i]` and were held
/// for `hold_bars[i]` bars, in direction `dirs[i]`. `TradeDir::Long` keeps rows whose
/// direction is `"多头"`, `TradeDir::Short` keeps `"空头"`, and `TradeDir::LongShort`
/// keeps every row. Rows with a non-positive count are skipped. A trade with profit
/// `>= 0` counts as a win and anything else as a loss.
///
/// When no trade qualifies, the result is all zeros but still carries `trade_dir`.
///
/// # Errors
///
/// Currently never returns `Err`.
///
/// # Panics
///
/// Every column is indexed up to `profit_bps.len()`, so a shorter `dirs`, `counts` or
/// `hold_bars` column panics.
pub fn evaluate_pairs_soa(
    pairs: &PairsSoA,
    trade_dir: TradeDir,
) -> Result<EvaluatePairs, WbtError> {
    let dir_filter = match trade_dir {
        TradeDir::Long => Some("多头"),
        TradeDir::Short => Some("空头"),
        TradeDir::LongShort => None,
    };
    let n = pairs.profit_bps.len();
    let mut trade_count = 0.0f64;
    let mut win_trade_count = 0.0f64;
    let mut sum_win = 0.0f64;
    let mut loss_trade_count = 0.0f64;
    let mut sum_loss = 0.0f64;
    let mut sum_hold_bars = 0.0f64;
    // (profit, count) of every qualifying row, consumed by the break-even computation.
    let mut profit_count_pairs: Vec<(f64, f64)> = Vec::with_capacity(n);
    for i in 0..n {
        if let Some(filter_str) = dir_filter
            && pairs.dirs[i] != filter_str
        {
            continue;
        }
        let p = pairs.profit_bps[i];
        let c = pairs.counts[i] as f64;
        if c <= 0.0 {
            continue;
        }
        trade_count += c;
        if p >= 0.0 {
            win_trade_count += c;
            sum_win += p * c;
        } else {
            loss_trade_count += c;
            sum_loss += p * c;
        }
        sum_hold_bars += (pairs.hold_bars[i] as f64) * c;
        profit_count_pairs.push((p, c));
    }
    if trade_count <= 0.0 {
        // `EvaluatePairs::default()` is tagged `LongShort`; keep the direction that was
        // actually requested so an empty, filtered result is not mislabelled.
        return Ok(EvaluatePairs {
            trade_direction: trade_dir,
            ..EvaluatePairs::default()
        });
    }
    let total_profit = sum_win + sum_loss;
    let position_k_days = sum_hold_bars / trade_count;
    let win_one = if win_trade_count > 0.0 {
        sum_win / win_trade_count
    } else {
        0.0
    };
    let loss_one = if loss_trade_count > 0.0 {
        sum_loss / loss_trade_count
    } else {
        0.0
    };
    let win_rate = win_trade_count / trade_count;
    let break_even_point = compute_break_even_point(&mut profit_count_pairs, trade_count);
    // Both ratios are reported as 0 (rather than inf/NaN) when there is no loss to divide by.
    let total_profit_loss_ratio = if sum_loss == 0.0 {
        0.0
    } else {
        sum_win / sum_loss.abs()
    };
    let single_profit_loss_ratio = if loss_one == 0.0 {
        0.0
    } else {
        win_one / loss_one.abs()
    };
    Ok(EvaluatePairs {
        trade_direction: trade_dir,
        trade_count: trade_count as usize,
        total_profit: total_profit.round_to_2_digit(),
        single_trade_profit: (total_profit / trade_count).round_to_2_digit(),
        win_trade_count: win_trade_count as usize,
        sum_win: sum_win.round_to_2_digit(),
        win_one: win_one.round_to_4_digit(),
        loss_trade_count: loss_trade_count as usize,
        sum_loss: sum_loss.round_to_2_digit(),
        loss_one: loss_one.round_to_4_digit(),
        win_rate: win_rate.round_to_4_digit(),
        total_profit_loss_ratio: total_profit_loss_ratio.round_to_4_digit(),
        single_profit_loss_ratio: single_profit_loss_ratio.round_to_4_digit(),
        break_even_point: break_even_point.round_to_4_digit(),
        position_k_days: position_k_days.round_to_2_digit(),
    })
}