use crate::indicators::Candle;
/// Direction of a position or trade.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Side {
    /// Profits when the price rises.
    Long,
    /// Profits when the price falls.
    Short,
}
/// An open position. The portfolio holds at most one at a time.
#[derive(Clone, Copy, Debug)]
pub struct Position {
    /// Direction of the exposure.
    pub side: Side,
    /// Number of units held; always positive, direction comes from `side`.
    pub quantity: f64,
    /// Fill price at entry, after slippage.
    pub entry_price: f64,
    /// Timestamp of the candle on which the position was opened.
    pub entry_timestamp: u64,
}
/// A completed round trip: one entry and the exit that closed it.
#[derive(Clone, Copy, Debug)]
pub struct Trade {
    /// Direction of the closed position.
    pub side: Side,
    /// Units traded.
    pub quantity: f64,
    /// Fill price at entry, after slippage.
    pub entry_price: f64,
    /// Fill price at exit, after slippage.
    pub exit_price: f64,
    /// Timestamp of the entry candle.
    pub entry_timestamp: u64,
    /// Timestamp of the exit candle.
    pub exit_timestamp: u64,
    /// Profit and loss net of `fees_paid`.
    pub pnl: f64,
    /// Sum of the entry and exit fees.
    pub fees_paid: f64,
}
/// How large an entry order should be.
#[derive(Clone, Copy, Debug)]
pub enum Quantity {
    /// An explicit number of units.
    Fixed(f64),
    /// As many units as the available cash can buy, fees included.
    AllCash,
    /// A fraction of available cash (`0.25` = 25%), clamped to `[0, 1]` when sized.
    PercentCash(f64),
}
/// Decision a [`Strategy`] returns for one bar.
#[derive(Clone, Copy, Debug)]
pub enum Action {
    /// Go long. An open short is closed first; an existing long is left unchanged.
    EnterLong(Quantity),
    /// Go short. An open long is closed first; an existing short is left unchanged.
    EnterShort(Quantity),
    /// Close any open position.
    Exit,
    /// Do nothing this bar.
    Hold,
}
/// Read-only view handed to [`Strategy::on_candle`] for a single bar.
pub struct Context<'a> {
    /// Portfolio state before this bar's action is applied.
    pub portfolio: &'a Portfolio,
    /// Zero-based index of the current candle in the input slice.
    pub candle_index: usize,
    /// Close price of the current candle.
    pub current_price: f64,
}
/// A trading strategy driven bar-by-bar by the backtester.
pub trait Strategy {
    /// Called once before the first candle. Does nothing by default.
    fn on_start(&mut self) {}

    /// Decides what to do at `candle`, given a read-only view of the run in `ctx`.
    fn on_candle(&mut self, candle: &Candle, ctx: &Context) -> Action;

    /// Called once after the last candle. Does nothing by default.
    fn on_finish(&mut self) {}
}
/// Cash, the (at most one) open position, and the history recorded during a run.
#[derive(Clone, Debug)]
pub struct Portfolio {
    /// Cash balance. Long entries debit the purchase cost plus fee; short entries credit the
    /// sale proceeds minus fee, so cash can exceed the starting balance while a short is open.
    pub cash: f64,
    /// The open position, if any.
    pub position: Option<Position>,
    /// `(timestamp, equity)` sampled once per candle, after that candle's action.
    pub equity_curve: Vec<(u64, f64)>,
    /// Closed round-trip trades, in the order they were closed.
    pub trades: Vec<Trade>,
}
impl Portfolio {
    /// Creates a portfolio holding `initial_cash`, with no position, trades, or equity history.
    pub fn new(initial_cash: f64) -> Self {
        Self {
            cash: initial_cash,
            position: None,
            equity_curve: Vec::new(),
            trades: Vec::new(),
        }
    }

    /// Returns the mark-to-market equity at `current_price`: cash plus the value of any open
    /// position. Exit fees and slippage for closing the position are not deducted.
    ///
    /// Cash already reflects the fill that opened the position. A long entry debited the purchase
    /// cost, so the long adds its market value back. A short entry credited the sale proceeds, so
    /// the short is a liability equal to the cost of buying the units back at `current_price`.
    pub fn equity(&self, current_price: f64) -> f64 {
        let pos_value = match self.position {
            None => 0.0,
            Some(p) => match p.side {
                Side::Long => p.quantity * current_price,
                // The proceeds are already in `cash`; adding entry value here would count them
                // again and overstate equity by roughly `2 * quantity * entry_price`.
                Side::Short => -p.quantity * current_price,
            },
        };
        self.cash + pos_value
    }
}
/// Simulation parameters for a backtest run.
#[derive(Clone, Debug)]
pub struct BacktestConfig {
    /// Starting cash balance.
    pub initial_cash: f64,
    /// Proportional fee charged on the notional of every fill (e.g. `0.001` = 0.1%).
    pub fee_rate: f64,
    /// Adverse price adjustment as a fraction of the close: buys fill at
    /// `close * (1 + slippage)`, sells at `close * (1 - slippage)`.
    pub slippage: f64,
    /// Number of bars per year, used to annualise the Sharpe ratio.
    pub periods_per_year: f64,
}
impl Default for BacktestConfig {
fn default() -> Self {
Self {
initial_cash: 10_000.0,
fee_rate: 0.0,
slippage: 0.0,
periods_per_year: 252.0,
}
}
}
/// Summary statistics of a finished backtest.
#[derive(Clone, Copy, Debug)]
pub struct Metrics {
    /// Last value of the equity curve, or the initial cash when no candles were processed.
    pub final_equity: f64,
    /// `(final_equity - initial_cash) / initial_cash`.
    pub total_return: f64,
    /// Largest peak-to-trough decline of equity, as a fraction of the peak.
    pub max_drawdown: f64,
    /// Annualised Sharpe ratio of the per-bar equity returns.
    pub sharpe: f64,
    /// Fraction of closed trades with a positive net pnl.
    pub win_rate: f64,
    /// Number of closed trades; a position still open at the end is not counted.
    pub trade_count: usize,
    /// Gross profit divided by absolute gross loss over closed trades.
    pub profit_factor: f64,
}
/// Outcome of a backtest run.
#[derive(Clone, Debug)]
pub struct BacktestResult {
    /// Final portfolio state, including its equity curve and closed trades.
    pub portfolio: Portfolio,
    /// Summary statistics computed from `portfolio`.
    pub metrics: Metrics,
}
/// Runs strategies over candle history using a fixed [`BacktestConfig`].
#[derive(Clone, Debug)]
pub struct Backtester {
    /// Parameters applied to every run.
    pub config: BacktestConfig,
}
impl Backtester {
pub fn new(config: BacktestConfig) -> Self {
Self { config }
}
pub fn run<S: Strategy>(&self, candles: &[Candle], strategy: &mut S) -> BacktestResult {
let mut portfolio = Portfolio::new(self.config.initial_cash);
strategy.on_start();
for (i, candle) in candles.iter().enumerate() {
let price = candle.close;
let action = {
let ctx = Context {
portfolio: &portfolio,
candle_index: i,
current_price: price,
};
strategy.on_candle(candle, &ctx)
};
apply_action(&mut portfolio, action, candle, &self.config);
let equity = portfolio.equity(price);
portfolio.equity_curve.push((candle.timestamp, equity));
}
strategy.on_finish();
let metrics = compute_metrics(&portfolio, &self.config);
BacktestResult { portfolio, metrics }
}
}
/// Price paid when buying at `close`: slippage moves the fill up by the fraction `slippage`.
fn fill_buy_price(close: f64, slippage: f64) -> f64 {
    let adverse_factor = 1.0 + slippage;
    adverse_factor * close
}
/// Price received when selling at `close`: slippage moves the fill down by the fraction `slippage`.
fn fill_sell_price(close: f64, slippage: f64) -> f64 {
    let adverse_factor = 1.0 - slippage;
    adverse_factor * close
}
/// Converts a [`Quantity`] request into a concrete number of units at `fill_price`.
///
/// Cash-based sizes (`AllCash`, `PercentCash`) divide the budget by the per-unit cost including
/// the proportional fee. Buying the returned units and paying the fee therefore spends
/// (approximately) that budget. `PercentCash` is a fraction clamped to `[0, 1]`. `Fixed` is
/// returned as given.
///
/// Returns `None` when nothing can be traded:
/// - the fill price is non-finite or non-positive;
/// - a cash-based size is requested without positive cash;
/// - the resulting unit count is not a positive, finite number.
fn resolve_quantity(qty: Quantity, cash: f64, fill_price: f64, fee_rate: f64) -> Option<f64> {
    // `NaN <= 0.0` is false, so finiteness must be checked explicitly. Otherwise a NaN price
    // (e.g. a missing close) would reach the caller's cash arithmetic through `Fixed`.
    if !fill_price.is_finite() || fill_price <= 0.0 {
        return None;
    }
    let effective_unit_cost = fill_price * (1.0 + fee_rate);
    let units = match qty {
        Quantity::Fixed(q) => q,
        Quantity::AllCash if cash > 0.0 => cash / effective_unit_cost,
        Quantity::PercentCash(pct) if cash > 0.0 => {
            (cash * pct.clamp(0.0, 1.0)) / effective_unit_cost
        }
        // Cash-based sizing with no (or NaN) cash cannot buy anything.
        Quantity::AllCash | Quantity::PercentCash(_) => return None,
    };
    // Rejects zero/negative sizes as well as the NaN or infinite results of a non-finite `Fixed`
    // request or a degenerate fee rate (e.g. `-1.0`, which zeroes the unit cost).
    (units.is_finite() && units > 0.0).then_some(units)
}
/// Closes the open position, if any, at this candle's close adjusted for slippage.
///
/// It settles cash with the exit fill and appends the round trip to `portfolio.trades`. The
/// recorded pnl is net of both the entry fee and the exit fee. Does nothing when the portfolio
/// is flat.
fn close_position(portfolio: &mut Portfolio, candle: &Candle, cfg: &BacktestConfig) {
    let pos = match portfolio.position.take() {
        Some(open) => open,
        None => return,
    };
    let qty = pos.quantity;
    let entry = pos.entry_price;

    // Leaving a long means selling; leaving a short means buying the units back.
    let (exit_price, gross_pnl) = match pos.side {
        Side::Long => {
            let px = fill_sell_price(candle.close, cfg.slippage);
            (px, qty * (px - entry))
        }
        Side::Short => {
            let px = fill_buy_price(candle.close, cfg.slippage);
            (px, qty * (entry - px))
        }
    };

    let exit_fee = exit_price * qty * cfg.fee_rate;
    match pos.side {
        Side::Long => portfolio.cash += qty * exit_price - exit_fee,
        Side::Short => portfolio.cash -= qty * exit_price + exit_fee,
    }

    // The entry fee is recomputed from the stored entry price, mirroring how it was charged.
    let entry_fee = entry * qty * cfg.fee_rate;
    let fees_paid = entry_fee + exit_fee;
    portfolio.trades.push(Trade {
        side: pos.side,
        quantity: qty,
        entry_price: entry,
        exit_price,
        entry_timestamp: pos.entry_timestamp,
        exit_timestamp: candle.timestamp,
        pnl: gross_pnl - fees_paid,
        fees_paid,
    });
}
/// Opens a position on `side` at this candle's close adjusted for slippage, sized by `qty`.
///
/// Long entries debit `units * price + fee` from cash. They are skipped when that cost exceeds
/// the available cash by more than a tiny rounding tolerance. Short entries credit the sale
/// proceeds minus the fee; no margin check is applied. The call is a no-op when the fill price
/// is not finite or `qty` resolves to no units.
///
/// Any existing position is overwritten, so callers must only invoke this while flat.
fn open_position(
    portfolio: &mut Portfolio,
    side: Side,
    qty: Quantity,
    candle: &Candle,
    cfg: &BacktestConfig,
) {
    // `AllCash`/`PercentCash` size units as `cash / (price * (1 + fee))`. Multiplying back can
    // overshoot the cash by a few ULPs, and a strict comparison would then silently drop the
    // order. This relative slack absorbs that rounding without allowing real over-spending.
    const CASH_ROUNDING_TOLERANCE: f64 = 1e-12;

    let fill_price = match side {
        Side::Long => fill_buy_price(candle.close, cfg.slippage),
        Side::Short => fill_sell_price(candle.close, cfg.slippage),
    };
    // A NaN/infinite close (e.g. a gap in the data) must never reach the cash arithmetic.
    if !fill_price.is_finite() {
        return;
    }
    let Some(units) = resolve_quantity(qty, portfolio.cash, fill_price, cfg.fee_rate) else {
        return;
    };
    let entry_fee = fill_price * units * cfg.fee_rate;
    match side {
        Side::Long => {
            let cost = units * fill_price + entry_fee;
            let budget = portfolio.cash + portfolio.cash.abs() * CASH_ROUNDING_TOLERANCE;
            if cost.is_nan() || cost > budget {
                return;
            }
            // `cost <= budget` bounds any negative remainder to rounding dust; clamp it to zero.
            portfolio.cash = (portfolio.cash - cost).max(0.0);
        }
        Side::Short => {
            portfolio.cash += units * fill_price - entry_fee;
        }
    }
    portfolio.position = Some(Position {
        side,
        quantity: units,
        entry_price: fill_price,
        entry_timestamp: candle.timestamp,
    });
}
/// Executes one strategy [`Action`] against the portfolio at this candle.
///
/// - `Hold` does nothing.
/// - `Exit` closes any open position (a no-op when flat).
/// - `EnterLong` / `EnterShort` first close a position held on the opposite side. They then
///   open the requested side only if the portfolio is flat.
///
/// An entry in the direction already held is ignored, so there is no pyramiding. If a reversal
/// closes the old position but the new entry cannot be filled, the portfolio is left flat.
fn apply_action(portfolio: &mut Portfolio, action: Action, candle: &Candle, cfg: &BacktestConfig) {
    let (target, qty) = match action {
        Action::Hold => return,
        Action::Exit => {
            close_position(portfolio, candle, cfg);
            return;
        }
        Action::EnterLong(qty) => (Side::Long, qty),
        Action::EnterShort(qty) => (Side::Short, qty),
    };

    let holding_opposite = matches!(portfolio.position, Some(p) if p.side != target);
    if holding_opposite {
        close_position(portfolio, candle, cfg);
    }
    if portfolio.position.is_none() {
        open_position(portfolio, target, qty, candle, cfg);
    }
}
/// Summarises a finished run: final equity, total return, maximum drawdown, annualised Sharpe
/// ratio, and statistics over closed trades.
///
/// `final_equity` is the last point of the equity curve, or `cfg.initial_cash` when the curve is
/// empty. An open position is therefore included at its last mark. Drawdown is measured against
/// a running peak seeded with `cfg.initial_cash`. Win rate and profit factor only consider
/// closed trades.
///
/// `profit_factor` is `gross_profit / |gross_loss|`. With no losing trades it is `INFINITY` if
/// any trade made money, and `0.0` otherwise (no trades, or only break-even trades).
fn compute_metrics(portfolio: &Portfolio, cfg: &BacktestConfig) -> Metrics {
    let final_equity = portfolio
        .equity_curve
        .last()
        .map_or(cfg.initial_cash, |&(_, e)| e);
    let total_return = if cfg.initial_cash > 0.0 {
        (final_equity - cfg.initial_cash) / cfg.initial_cash
    } else {
        0.0
    };

    // Largest peak-to-trough decline as a fraction of the running peak. Seeding the peak with
    // the initial cash makes a loss on the very first bar count as drawdown.
    let mut peak = cfg.initial_cash;
    let mut max_drawdown = 0.0_f64;
    for &(_, eq) in &portfolio.equity_curve {
        if eq > peak {
            peak = eq;
        }
        if peak > 0.0 {
            max_drawdown = max_drawdown.max((peak - eq) / peak);
        }
    }

    let sharpe = sharpe_from_equity_curve(&portfolio.equity_curve, cfg.periods_per_year);

    let trade_count = portfolio.trades.len();
    // Break-even trades (pnl == 0) are neither wins nor losses; they add 0 to `gross_loss`.
    let (wins, gross_profit, gross_loss) = portfolio.trades.iter().fold(
        (0usize, 0.0_f64, 0.0_f64),
        |(w, gp, gl), t| {
            if t.pnl > 0.0 {
                (w + 1, gp + t.pnl, gl)
            } else {
                (w, gp, gl + t.pnl)
            }
        },
    );
    let win_rate = if trade_count == 0 {
        0.0
    } else {
        wins as f64 / trade_count as f64
    };
    let profit_factor = if gross_loss != 0.0 {
        gross_profit / gross_loss.abs()
    } else if gross_profit > 0.0 {
        f64::INFINITY
    } else {
        // No trades, or only break-even ones: 0/0 must not be reported as a perfect strategy.
        0.0
    };

    Metrics {
        final_equity,
        total_return,
        max_drawdown,
        sharpe,
        win_rate,
        trade_count,
        profit_factor,
    }
}
/// Annualised Sharpe ratio (zero risk-free rate) of the per-bar simple returns of `curve`.
///
/// A return is computed only between consecutive points whose earlier equity is strictly
/// positive. The ratio uses the sample standard deviation (`n - 1` denominator) and is scaled
/// by `sqrt(periods_per_year)`. Returns `0.0` when fewer than two returns are available or when
/// the returns have zero dispersion.
fn sharpe_from_equity_curve(curve: &[(u64, f64)], periods_per_year: f64) -> f64 {
    let mut returns: Vec<f64> = Vec::with_capacity(curve.len().saturating_sub(1));
    for pair in curve.windows(2) {
        let (prev, cur) = (pair[0].1, pair[1].1);
        // A return relative to zero or negative equity is meaningless; skip it.
        if prev > 0.0 {
            returns.push((cur - prev) / prev);
        }
    }

    let count = returns.len();
    if count < 2 {
        return 0.0;
    }
    let mean = returns.iter().sum::<f64>() / count as f64;
    let sum_sq_dev: f64 = returns.iter().map(|r| (r - mean).powi(2)).sum();
    let std_dev = (sum_sq_dev / (count - 1) as f64).sqrt();
    if std_dev == 0.0 {
        0.0
    } else {
        (mean / std_dev) * periods_per_year.sqrt()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` candles with closes 1.0, 2.0, ..., n and timestamps 1..=n.
    fn ramp(n: usize) -> Vec<Candle> {
        (1..=n)
            .map(|i| Candle {
                timestamp: i as u64,
                open: i as f64,
                high: i as f64 + 0.5,
                low: i as f64 - 0.5,
                close: i as f64,
                volume: 1.0,
            })
            .collect()
    }

    /// Goes all-in long on the first bar and holds to the end.
    struct BuyAndHold {
        entered: bool,
    }
    impl Strategy for BuyAndHold {
        fn on_candle(&mut self, _c: &Candle, _ctx: &Context) -> Action {
            if !self.entered {
                self.entered = true;
                Action::EnterLong(Quantity::AllCash)
            } else {
                Action::Hold
            }
        }
    }

    /// Goes all-in long on bar 0 and exits on bar 4 (close 5.0 on a `ramp`).
    struct OneTrade;
    impl Strategy for OneTrade {
        fn on_candle(&mut self, _c: &Candle, ctx: &Context) -> Action {
            match ctx.candle_index {
                0 => Action::EnterLong(Quantity::AllCash),
                4 => Action::Exit,
                _ => Action::Hold,
            }
        }
    }

    /// Never trades.
    struct AlwaysHold;
    impl Strategy for AlwaysHold {
        fn on_candle(&mut self, _c: &Candle, _ctx: &Context) -> Action {
            Action::Hold
        }
    }

    #[test]
    fn buy_and_hold_appreciates_with_price() {
        let bt = Backtester::new(BacktestConfig::default());
        let candles = ramp(10);
        let mut s = BuyAndHold { entered: false };
        let res = bt.run(&candles, &mut s);
        // 10_000 units bought at 1.0, marked at the final close of 10.0.
        assert!((res.metrics.final_equity - 100_000.0).abs() < 1e-9);
        assert!((res.metrics.total_return - 9.0).abs() < 1e-9);
        assert_eq!(res.metrics.trade_count, 0, "an open position is not a closed trade");
        assert!(res.portfolio.position.is_some());
    }

    #[test]
    fn one_trade_is_recorded_with_correct_pnl() {
        let bt = Backtester::new(BacktestConfig::default());
        let candles = ramp(10);
        let res = bt.run(&candles, &mut OneTrade);
        assert_eq!(res.portfolio.trades.len(), 1);
        let t = res.portfolio.trades[0];
        assert_eq!(t.side, Side::Long);
        assert!((t.quantity - 10_000.0).abs() < 1e-9);
        assert_eq!(t.entry_price, 1.0);
        assert_eq!(t.exit_price, 5.0);
        assert!((t.pnl - 40_000.0).abs() < 1e-9);
        assert!(res.portfolio.position.is_none());
        assert_eq!(res.metrics.win_rate, 1.0);
    }

    #[test]
    fn fees_reduce_pnl() {
        let cfg = BacktestConfig {
            fee_rate: 0.01,
            ..Default::default()
        };
        let bt = Backtester::new(cfg);
        let res = bt.run(&ramp(10), &mut OneTrade);
        let t = res.portfolio.trades[0];
        assert!(t.pnl < 40_000.0);
        assert!(t.fees_paid > 0.0);
    }

    #[test]
    fn slippage_widens_spread() {
        let cfg = BacktestConfig {
            slippage: 0.01,
            ..Default::default()
        };
        let bt = Backtester::new(cfg);
        let res = bt.run(&ramp(10), &mut OneTrade);
        let t = res.portfolio.trades[0];
        assert!(t.entry_price > 1.0);
        assert!(t.exit_price < 5.0);
    }

    /// Shorts a fixed 100 units on bar 0 and exits on bar 4.
    struct ShortStrategy;
    impl Strategy for ShortStrategy {
        fn on_candle(&mut self, _c: &Candle, ctx: &Context) -> Action {
            match ctx.candle_index {
                0 => Action::EnterShort(Quantity::Fixed(100.0)),
                4 => Action::Exit,
                _ => Action::Hold,
            }
        }
    }

    #[test]
    fn short_in_uptrend_loses_money() {
        let bt = Backtester::new(BacktestConfig::default());
        let res = bt.run(&ramp(10), &mut ShortStrategy);
        assert_eq!(res.portfolio.trades.len(), 1);
        let t = res.portfolio.trades[0];
        assert_eq!(t.side, Side::Short);
        assert!((t.pnl - (-400.0)).abs() < 1e-9);
        assert!(res.metrics.total_return < 0.0);
    }

    #[test]
    fn equity_curve_is_sampled_each_bar() {
        let bt = Backtester::new(BacktestConfig::default());
        let candles = ramp(7);
        let mut s = BuyAndHold { entered: false };
        let res = bt.run(&candles, &mut s);
        assert_eq!(res.portfolio.equity_curve.len(), candles.len());
    }

    #[test]
    fn flipping_long_to_short_closes_first_position() {
        struct Flip;
        impl Strategy for Flip {
            fn on_candle(&mut self, _c: &Candle, ctx: &Context) -> Action {
                match ctx.candle_index {
                    0 => Action::EnterLong(Quantity::Fixed(100.0)),
                    3 => Action::EnterShort(Quantity::Fixed(50.0)),
                    _ => Action::Hold,
                }
            }
        }
        let bt = Backtester::new(BacktestConfig::default());
        let res = bt.run(&ramp(10), &mut Flip);
        assert_eq!(
            res.portfolio.trades.len(),
            1,
            "long should close when short order arrives"
        );
        assert!(matches!(
            res.portfolio.position,
            Some(p) if p.side == Side::Short
        ));
    }

    #[test]
    fn metrics_are_zero_for_empty_input() {
        let bt = Backtester::new(BacktestConfig::default());
        let res = bt.run(&[], &mut BuyAndHold { entered: false });
        assert_eq!(res.metrics.final_equity, 10_000.0);
        assert_eq!(res.metrics.trade_count, 0);
        assert_eq!(res.metrics.total_return, 0.0);
        assert_eq!(res.metrics.max_drawdown, 0.0);
    }

    #[test]
    fn holding_cash_leaves_equity_flat() {
        let bt = Backtester::new(BacktestConfig::default());
        let res = bt.run(&ramp(5), &mut AlwaysHold);
        assert!(res.portfolio.equity_curve.iter().all(|&(_, e)| e == 10_000.0));
        assert_eq!(res.metrics.total_return, 0.0);
        assert_eq!(res.metrics.max_drawdown, 0.0);
        assert_eq!(res.metrics.sharpe, 0.0);
        assert_eq!(res.metrics.trade_count, 0);
    }

    #[test]
    fn percent_cash_sizes_from_available_cash() {
        struct HalfIn;
        impl Strategy for HalfIn {
            fn on_candle(&mut self, _c: &Candle, ctx: &Context) -> Action {
                if ctx.candle_index == 0 {
                    Action::EnterLong(Quantity::PercentCash(0.5))
                } else {
                    Action::Hold
                }
            }
        }
        let bt = Backtester::new(BacktestConfig::default());
        let res = bt.run(&ramp(3), &mut HalfIn);
        let pos = res.portfolio.position.expect("half-cash long should be open");
        assert_eq!(pos.quantity, 5_000.0);
        assert_eq!(res.portfolio.cash, 5_000.0);
        assert_eq!(res.metrics.final_equity, 20_000.0);
    }

    #[test]
    fn context_reports_bar_index_and_close() {
        struct Recorder {
            seen: Vec<(usize, f64, f64)>,
        }
        impl Strategy for Recorder {
            fn on_candle(&mut self, c: &Candle, ctx: &Context) -> Action {
                self.seen.push((ctx.candle_index, ctx.current_price, c.close));
                Action::Hold
            }
        }
        let mut rec = Recorder { seen: Vec::new() };
        Backtester::new(BacktestConfig::default()).run(&ramp(3), &mut rec);
        assert_eq!(rec.seen, vec![(0, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 3.0)]);
    }

    #[test]
    fn lifecycle_hooks_run_once_each() {
        struct Hooks {
            starts: u32,
            bars: u32,
            finishes: u32,
        }
        impl Strategy for Hooks {
            fn on_candle(&mut self, _c: &Candle, _ctx: &Context) -> Action {
                self.bars += 1;
                Action::Hold
            }
            fn on_start(&mut self) {
                self.starts += 1;
            }
            fn on_finish(&mut self) {
                self.finishes += 1;
            }
        }
        let mut h = Hooks { starts: 0, bars: 0, finishes: 0 };
        Backtester::new(BacktestConfig::default()).run(&ramp(4), &mut h);
        assert_eq!((h.starts, h.bars, h.finishes), (1, 4, 1));
    }
}