use std::collections::HashMap;
pub fn trade_stats(pnl: &[f64], hold_bars: &[f64]) -> (f64, f64, f64, f64, f64) {
let n = pnl.len();
assert!(n > 0, "pnl must be non-empty");
assert_eq!(
n,
hold_bars.len(),
"pnl and hold_bars must have equal length"
);
let mut wins: Vec<f64> = Vec::new();
let mut losses: Vec<f64> = Vec::new();
for &v in pnl.iter() {
if v > 0.0 {
wins.push(v);
} else if v < 0.0 {
losses.push(v);
}
}
let win_rate = wins.len() as f64 / n as f64;
let avg_win = if wins.is_empty() {
0.0
} else {
wins.iter().sum::<f64>() / wins.len() as f64
};
let avg_loss = if losses.is_empty() {
0.0
} else {
losses.iter().sum::<f64>() / losses.len() as f64
};
let gross_profit: f64 = wins.iter().sum();
let gross_loss: f64 = losses.iter().map(|v| v.abs()).sum();
let profit_factor = if gross_loss == 0.0 {
f64::INFINITY
} else {
gross_profit / gross_loss
};
let avg_hold = hold_bars.iter().sum::<f64>() / n as f64;
(win_rate, avg_win, avg_loss, profit_factor, avg_hold)
}
pub fn monthly_contribution(bar_returns: &[f64], month_index: &[i64]) -> (Vec<i64>, Vec<f64>) {
let n = bar_returns.len();
assert_eq!(
n,
month_index.len(),
"bar_returns and month_index must have equal length"
);
let mut map: HashMap<i64, f64> = HashMap::new();
for i in 0..n {
if !bar_returns[i].is_nan() {
*map.entry(month_index[i]).or_insert(0.0) += bar_returns[i];
}
}
let mut months: Vec<i64> = map.keys().copied().collect();
months.sort_unstable();
let contributions: Vec<f64> = months.iter().map(|m| map[m]).collect();
(months, contributions)
}
pub fn signal_attribution(bar_returns: &[f64], signal_labels: &[i64]) -> (Vec<i64>, Vec<f64>) {
let n = bar_returns.len();
assert_eq!(
n,
signal_labels.len(),
"bar_returns and signal_labels must have equal length"
);
let mut map: HashMap<i64, f64> = HashMap::new();
for i in 0..n {
if !bar_returns[i].is_nan() {
*map.entry(signal_labels[i]).or_insert(0.0) += bar_returns[i];
}
}
let mut labels: Vec<i64> = map.keys().copied().collect();
labels.sort_unstable();
let contributions: Vec<f64> = labels.iter().map(|l| map[l]).collect();
(labels, contributions)
}
pub fn extract_trades(positions: &[f64], strategy_returns: &[f64]) -> (Vec<f64>, Vec<f64>) {
let n = positions.len();
assert_eq!(
n,
strategy_returns.len(),
"positions and strategy_returns must have equal length"
);
let mut pnl = Vec::<f64>::new();
let mut hold = Vec::<f64>::new();
let mut i = 0usize;
while i < n {
if positions[i] == 0.0 {
i += 1;
continue;
}
let mut j = i + 1;
while j < n && positions[j] == positions[i] {
j += 1;
}
let mut trade_pnl = 0.0_f64;
for v in strategy_returns.iter().take(j).skip(i) {
trade_pnl += *v;
}
pnl.push(trade_pnl);
hold.push((j - i) as f64);
i = j;
}
(pnl, hold)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_trade_stats_basic() {
let pnl = [100.0, -50.0, 200.0, -30.0, 150.0];
let hold = [5.0, 3.0, 7.0, 2.0, 6.0];
let (wr, aw, al, pf, ah) = trade_stats(&pnl, &hold);
assert!((wr - 0.6).abs() < 1e-10);
assert!((aw - 150.0).abs() < 1e-10);
assert!((al - (-40.0)).abs() < 1e-10);
assert!((pf - 5.625).abs() < 1e-10);
assert!((ah - 4.6).abs() < 1e-10);
}
#[test]
fn test_trade_stats_all_wins() {
let pnl = [10.0, 20.0];
let hold = [1.0, 2.0];
let (wr, _aw, al, pf, _ah) = trade_stats(&pnl, &hold);
assert!((wr - 1.0).abs() < 1e-10);
assert!((al - 0.0).abs() < 1e-10);
assert!(pf.is_infinite());
}
#[test]
fn test_trade_stats_all_losses() {
let pnl = [-10.0, -20.0];
let hold = [1.0, 2.0];
let (wr, aw, _al, pf, _ah) = trade_stats(&pnl, &hold);
assert!((wr - 0.0).abs() < 1e-10);
assert!((aw - 0.0).abs() < 1e-10);
assert!((pf - 0.0).abs() < 1e-10);
}
#[test]
#[should_panic(expected = "pnl must be non-empty")]
fn test_trade_stats_empty() {
trade_stats(&[], &[]);
}
#[test]
fn test_monthly_contribution_basic() {
let returns = [0.01, 0.02, -0.01, 0.03, -0.02];
let months = [0, 0, 1, 1, 2];
let (m, c) = monthly_contribution(&returns, &months);
assert_eq!(m, vec![0, 1, 2]);
assert!((c[0] - 0.03).abs() < 1e-10);
assert!((c[1] - 0.02).abs() < 1e-10);
assert!((c[2] - (-0.02)).abs() < 1e-10);
}
#[test]
fn test_monthly_contribution_nan_skipped() {
let returns = [0.01, f64::NAN, 0.03];
let months = [0, 0, 1];
let (m, c) = monthly_contribution(&returns, &months);
assert_eq!(m, vec![0, 1]);
assert!((c[0] - 0.01).abs() < 1e-10);
assert!((c[1] - 0.03).abs() < 1e-10);
}
#[test]
fn test_monthly_contribution_empty() {
let (m, c) = monthly_contribution(&[], &[]);
assert!(m.is_empty());
assert!(c.is_empty());
}
#[test]
fn test_signal_attribution_basic() {
let returns = [0.05, -0.02, 0.03, 0.01];
let labels = [1, -1, 2, 1];
let (l, c) = signal_attribution(&returns, &labels);
assert_eq!(l, vec![-1, 1, 2]);
assert!((c[0] - (-0.02)).abs() < 1e-10);
assert!((c[1] - 0.06).abs() < 1e-10); assert!((c[2] - 0.03).abs() < 1e-10);
}
#[test]
fn test_signal_attribution_nan_skipped() {
let returns = [0.05, f64::NAN];
let labels = [1, 2];
let (l, c) = signal_attribution(&returns, &labels);
assert_eq!(l, vec![1]);
assert!((c[0] - 0.05).abs() < 1e-10);
}
#[test]
fn test_extract_trades_basic() {
let positions = [0.0, 1.0, 1.0, 0.0, -1.0, -1.0];
let strat_ret = [0.0, 0.01, 0.02, 0.0, -0.01, 0.03];
let (pnl, hold) = extract_trades(&positions, &strat_ret);
assert_eq!(pnl.len(), 2);
assert_eq!(hold.len(), 2);
assert!((pnl[0] - 0.03).abs() < 1e-10);
assert!((hold[0] - 2.0).abs() < 1e-10);
assert!((pnl[1] - 0.02).abs() < 1e-10);
assert!((hold[1] - 2.0).abs() < 1e-10);
}
#[test]
fn test_extract_trades_all_flat() {
let positions = [0.0, 0.0, 0.0];
let strat_ret = [0.01, 0.02, 0.03];
let (pnl, hold) = extract_trades(&positions, &strat_ret);
assert!(pnl.is_empty());
assert!(hold.is_empty());
}
#[test]
fn test_extract_trades_empty() {
let (pnl, hold) = extract_trades(&[], &[]);
assert!(pnl.is_empty());
assert!(hold.is_empty());
}
#[test]
fn test_extract_trades_single_bar_trade() {
let positions = [0.0, 1.0, 0.0];
let strat_ret = [0.0, 0.05, 0.0];
let (pnl, hold) = extract_trades(&positions, &strat_ret);
assert_eq!(pnl.len(), 1);
assert!((pnl[0] - 0.05).abs() < 1e-10);
assert!((hold[0] - 1.0).abs() < 1e-10);
}
}