use crate::Decimal;
use crate::execution::Side;
use crate::position::inventory::InventoryPosition;
use crate::position::pnl::PnL;
use crate::strategy::quote::Quote;
use super::data::{HistoricalDataSource, MarketTick};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// A single fill produced by the backtest simulator.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SimulatedFill {
    /// Whether this fill bought or sold.
    pub side: Side,
    /// Execution price (after any slippage the engine applied).
    pub price: Decimal,
    /// Filled quantity (unsigned; direction is given by `side`).
    pub quantity: Decimal,
    /// Time of the fill; the engine uses the timestamp of the triggering tick.
    pub timestamp: u64,
    /// Absolute fee charged for this fill (an amount, not a rate).
    pub fee: Decimal,
}

impl SimulatedFill {
    /// Builds a fee-free fill.
    #[must_use]
    pub fn new(side: Side, price: Decimal, quantity: Decimal, timestamp: u64) -> Self {
        Self::with_fee(side, price, quantity, timestamp, Decimal::ZERO)
    }

    /// Builds a fill that carries an explicit absolute `fee`.
    #[must_use]
    pub fn with_fee(
        side: Side,
        price: Decimal,
        quantity: Decimal,
        timestamp: u64,
        fee: Decimal,
    ) -> Self {
        Self {
            side,
            price,
            quantity,
            timestamp,
            fee,
        }
    }

    /// Traded value of the fill: `price * quantity` (fees not included).
    #[must_use]
    pub fn notional(&self) -> Decimal {
        let (px, qty) = (self.price, self.quantity);
        px * qty
    }
}
/// Model for the adverse price offset applied to simulated fills.
///
/// [`SlippageModel::calculate_slippage`] returns an unsigned absolute price offset.
/// The engine in this module adds it to buy prices and subtracts it from sell prices.
///
/// NOTE(review): the engine in this module always passes a volatility of zero, so
/// `VolatilityBased` currently yields no slippage there.
#[derive(Debug, Clone, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SlippageModel {
    /// No slippage (the default).
    #[default]
    None,
    /// A constant absolute price offset, independent of the price.
    Fixed(Decimal),
    /// A fraction of the price (`price * fraction`).
    Percentage(Decimal),
    /// Scales with volatility: `price * volatility * multiplier`.
    VolatilityBased {
        /// Factor applied to `price * volatility`.
        multiplier: Decimal,
    },
}

impl SlippageModel {
    /// Computes the absolute slippage for a fill at `price`.
    ///
    /// `volatility` is only consulted by [`SlippageModel::VolatilityBased`].
    #[must_use]
    pub fn calculate_slippage(&self, price: Decimal, volatility: Decimal) -> Decimal {
        match *self {
            Self::None => Decimal::ZERO,
            Self::Fixed(offset) => offset,
            Self::Percentage(fraction) => price * fraction,
            Self::VolatilityBased { multiplier } => price * volatility * multiplier,
        }
    }
}
/// Settings for a backtest run.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BacktestConfig {
    /// Starting equity; the engine reports equity as `initial_capital + pnl - fees`.
    pub initial_capital: Decimal,
    /// Fee per fill as a fraction of the fill's notional.
    pub fee_rate: Decimal,
    /// Price increment. NOTE(review): the engine code alongside this config never reads it.
    pub tick_size: Decimal,
    /// Quantity increment. NOTE(review): the engine code alongside this config never reads it.
    pub lot_size: Decimal,
    /// Slippage applied to fill prices.
    pub slippage: SlippageModel,
    /// Quantity used for every simulated fill.
    pub default_order_size: Decimal,
    /// Whether to keep a `(timestamp, equity)` sample per tick.
    pub record_equity_curve: bool,
    /// Whether to keep every fill in the result.
    pub record_trades: bool,
}

impl Default for BacktestConfig {
    /// Defaults: capital 100 000, no fees, no slippage, tick 0.01, lot 0.001,
    /// order size 1, equity curve and trades recorded.
    fn default() -> Self {
        let tick_size = Decimal::from_str_exact("0.01").unwrap();
        let lot_size = Decimal::from_str_exact("0.001").unwrap();
        Self {
            initial_capital: Decimal::from(100_000),
            fee_rate: Decimal::ZERO,
            tick_size,
            lot_size,
            slippage: SlippageModel::default(),
            default_order_size: Decimal::ONE,
            record_equity_curve: true,
            record_trades: true,
        }
    }
}

impl BacktestConfig {
    /// Same as [`BacktestConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the starting equity.
    #[must_use]
    pub fn with_initial_capital(self, capital: Decimal) -> Self {
        Self {
            initial_capital: capital,
            ..self
        }
    }

    /// Sets the per-fill fee rate (fraction of notional).
    #[must_use]
    pub fn with_fee_rate(self, rate: Decimal) -> Self {
        Self {
            fee_rate: rate,
            ..self
        }
    }

    /// Sets the price increment.
    #[must_use]
    pub fn with_tick_size(self, size: Decimal) -> Self {
        Self {
            tick_size: size,
            ..self
        }
    }

    /// Sets the quantity increment.
    #[must_use]
    pub fn with_lot_size(self, size: Decimal) -> Self {
        Self {
            lot_size: size,
            ..self
        }
    }

    /// Sets the slippage model.
    #[must_use]
    pub fn with_slippage(self, slippage: SlippageModel) -> Self {
        Self { slippage, ..self }
    }

    /// Sets the quantity used for each simulated fill.
    #[must_use]
    pub fn with_default_order_size(self, size: Decimal) -> Self {
        Self {
            default_order_size: size,
            ..self
        }
    }

    /// Enables or disables equity-curve recording.
    #[must_use]
    pub fn with_record_equity_curve(self, record: bool) -> Self {
        Self {
            record_equity_curve: record,
            ..self
        }
    }

    /// Enables or disables trade recording.
    #[must_use]
    pub fn with_record_trades(self, record: bool) -> Self {
        Self {
            record_trades: record,
            ..self
        }
    }
}
/// Summary produced by a backtest run.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BacktestResult {
    /// Mark-to-market PnL at the end of the run, before fees.
    pub total_pnl: Decimal,
    /// Sum of all fill fees.
    pub total_fees: Decimal,
    /// `total_pnl - total_fees`.
    pub net_pnl: Decimal,
    /// Number of simulated fills.
    pub num_trades: u64,
    /// Number of ticks processed.
    pub num_ticks: u64,
    /// Timestamp of the first tick (0 if no ticks were processed).
    pub start_time: u64,
    /// Timestamp of the last tick (0 if no ticks were processed).
    pub end_time: u64,
    /// Largest absolute position held.
    pub max_position: Decimal,
    /// Signed position at the end of the run.
    pub final_position: Decimal,
    /// `(timestamp, equity)` samples; empty unless equity-curve recording was enabled.
    pub equity_curve: Vec<(u64, Decimal)>,
    /// Fills; empty unless trade recording was enabled.
    pub trades: Vec<SimulatedFill>,
    /// Largest drop of equity below its running peak (absolute, not a percentage).
    pub max_drawdown: Decimal,
    /// Per-tick (non-annualised) Sharpe ratio of equity returns, when computable.
    pub sharpe_ratio: Option<Decimal>,
}

impl BacktestResult {
    /// NOTE: despite the name, this is not a per-trade win rate.
    ///
    /// It returns `1` when at least one trade happened and the whole run's
    /// `net_pnl` is positive, and `0` otherwise.
    #[must_use]
    pub fn win_rate(&self) -> Decimal {
        if self.num_trades > 0 && self.net_pnl > Decimal::ZERO {
            Decimal::ONE
        } else {
            Decimal::ZERO
        }
    }

    /// Net PnL divided by the number of trades; `0` when there were no trades.
    #[must_use]
    pub fn avg_trade_pnl(&self) -> Decimal {
        match self.num_trades {
            0 => Decimal::ZERO,
            trades => self.net_pnl / Decimal::from(trades),
        }
    }

    /// `end_time - start_time`, clamped at zero.
    ///
    /// The unit is whatever the tick timestamps use; milliseconds are assumed by the name.
    #[must_use]
    pub fn duration_ms(&self) -> u64 {
        let (first, last) = (self.start_time, self.end_time);
        last.saturating_sub(first)
    }

    /// Net PnL as a fraction of `initial_capital`; `0` for non-positive capital.
    #[must_use]
    pub fn return_on_capital(&self, initial_capital: Decimal) -> Decimal {
        if initial_capital <= Decimal::ZERO {
            return Decimal::ZERO;
        }
        self.net_pnl / initial_capital
    }
}
/// A strategy driven tick-by-tick by [`BacktestEngine`].
///
/// For every tick the engine calls [`on_tick`](Self::on_tick). A side of the returned
/// quote is filled when the market crosses it: the bid fills when the tick's ask is at
/// or below the quoted bid, and the ask fills when the tick's bid is at or above the
/// quoted ask. Each fill is reported through [`on_fill`](Self::on_fill).
///
/// NOTE: the engine fills `BacktestConfig::default_order_size`, not the quote's
/// `bid_size`/`ask_size`.
pub trait BacktestStrategy {
    /// Called once per tick, with the position as it stood before this tick's fills.
    /// Return `None` to place no quote on this tick.
    fn on_tick(&mut self, tick: &MarketTick, position: &InventoryPosition) -> Option<Quote>;
    /// Called for each simulated fill, after the engine has applied it to its
    /// position and accrued its fee.
    fn on_fill(&mut self, fill: &SimulatedFill);
    /// Called from the engine's `reset`; should clear any per-run strategy state.
    fn reset(&mut self);
}
#[derive(Debug)]
pub struct BacktestEngine<S: BacktestStrategy, D: HistoricalDataSource> {
config: BacktestConfig,
strategy: S,
data_source: D,
position: InventoryPosition,
pnl: PnL,
equity_curve: Vec<(u64, Decimal)>,
trades: Vec<SimulatedFill>,
total_fees: Decimal,
max_position: Decimal,
peak_equity: Decimal,
max_drawdown: Decimal,
}
impl<S: BacktestStrategy, D: HistoricalDataSource> BacktestEngine<S, D> {
#[must_use]
pub fn new(config: BacktestConfig, strategy: S, data_source: D) -> Self {
let initial_capital = config.initial_capital;
Self {
config,
strategy,
data_source,
position: InventoryPosition::new(),
pnl: PnL::new(),
equity_curve: Vec::new(),
trades: Vec::new(),
total_fees: Decimal::ZERO,
max_position: Decimal::ZERO,
peak_equity: initial_capital,
max_drawdown: Decimal::ZERO,
}
}
pub fn run(&mut self) -> BacktestResult {
self.run_with_progress(|_, _| {})
}
pub fn run_with_progress<F: FnMut(usize, usize)>(&mut self, mut callback: F) -> BacktestResult {
let total_ticks = self.data_source.len();
let mut num_ticks = 0u64;
let mut start_time = 0u64;
let mut end_time = 0u64;
while let Some(tick) = self.data_source.next_tick() {
if num_ticks == 0 {
start_time = tick.timestamp;
}
end_time = tick.timestamp;
if let Some(quote) = self.strategy.on_tick(&tick, &self.position) {
self.simulate_fills(&tick, "e);
}
let mid_price = tick.mid_price();
self.pnl.unrealized = self.position.quantity * mid_price;
self.pnl.total = self.pnl.realized + self.pnl.unrealized;
let equity = self.config.initial_capital + self.pnl.total - self.total_fees;
if self.config.record_equity_curve {
self.equity_curve.push((tick.timestamp, equity));
}
if equity > self.peak_equity {
self.peak_equity = equity;
}
let drawdown = self.peak_equity - equity;
if drawdown > self.max_drawdown {
self.max_drawdown = drawdown;
}
num_ticks += 1;
callback(num_ticks as usize, total_ticks);
}
BacktestResult {
total_pnl: self.pnl.total,
total_fees: self.total_fees,
net_pnl: self.pnl.total - self.total_fees,
num_trades: self.trades.len() as u64,
num_ticks,
start_time,
end_time,
max_position: self.max_position,
final_position: self.position.quantity,
equity_curve: if self.config.record_equity_curve {
self.equity_curve.clone()
} else {
Vec::new()
},
trades: if self.config.record_trades {
self.trades.clone()
} else {
Vec::new()
},
max_drawdown: self.max_drawdown,
sharpe_ratio: self.calculate_sharpe_ratio(),
}
}
fn simulate_fills(&mut self, tick: &MarketTick, quote: &Quote) {
if tick.ask_price <= quote.bid_price {
let fill_price = self.apply_slippage(quote.bid_price, Side::Buy);
let fill = self.create_fill(Side::Buy, fill_price, tick.timestamp);
self.process_fill(fill);
}
if tick.bid_price >= quote.ask_price {
let fill_price = self.apply_slippage(quote.ask_price, Side::Sell);
let fill = self.create_fill(Side::Sell, fill_price, tick.timestamp);
self.process_fill(fill);
}
}
fn apply_slippage(&self, price: Decimal, side: Side) -> Decimal {
let slippage = self
.config
.slippage
.calculate_slippage(price, Decimal::ZERO);
match side {
Side::Buy => price + slippage,
Side::Sell => price - slippage,
}
}
fn create_fill(&self, side: Side, price: Decimal, timestamp: u64) -> SimulatedFill {
let quantity = self.config.default_order_size;
let notional = price * quantity;
let fee = notional * self.config.fee_rate;
SimulatedFill::with_fee(side, price, quantity, timestamp, fee)
}
fn process_fill(&mut self, fill: SimulatedFill) {
let signed_qty = match fill.side {
Side::Buy => fill.quantity,
Side::Sell => -fill.quantity,
};
self.position
.update_fill(signed_qty, fill.price, fill.timestamp);
let cash_flow = match fill.side {
Side::Buy => -fill.notional(),
Side::Sell => fill.notional(),
};
self.pnl.add_realized(cash_flow);
self.total_fees += fill.fee;
let abs_position = self.position.quantity.abs();
if abs_position > self.max_position {
self.max_position = abs_position;
}
self.strategy.on_fill(&fill);
if self.config.record_trades {
self.trades.push(fill);
}
}
fn calculate_sharpe_ratio(&self) -> Option<Decimal> {
if self.equity_curve.len() < 2 {
return None;
}
let returns: Vec<Decimal> = self
.equity_curve
.windows(2)
.filter_map(|w| {
if w[0].1 > Decimal::ZERO {
Some((w[1].1 - w[0].1) / w[0].1)
} else {
None
}
})
.collect();
if returns.is_empty() {
return None;
}
let n = Decimal::from(returns.len() as u64);
let mean: Decimal = returns.iter().sum::<Decimal>() / n;
let variance: Decimal = returns
.iter()
.map(|r| {
let diff = *r - mean;
diff * diff
})
.sum::<Decimal>()
/ n;
if variance <= Decimal::ZERO {
return None;
}
let std_dev = decimal_sqrt(variance)?;
if std_dev > Decimal::ZERO {
Some(mean / std_dev)
} else {
None
}
}
#[must_use]
pub fn get_state(&self) -> (&InventoryPosition, &PnL) {
(&self.position, &self.pnl)
}
#[must_use]
pub fn strategy(&self) -> &S {
&self.strategy
}
pub fn strategy_mut(&mut self) -> &mut S {
&mut self.strategy
}
pub fn reset(&mut self) {
self.data_source.reset();
self.strategy.reset();
self.position = InventoryPosition::new();
self.pnl = PnL::new();
self.equity_curve.clear();
self.trades.clear();
self.total_fees = Decimal::ZERO;
self.max_position = Decimal::ZERO;
self.peak_equity = self.config.initial_capital;
self.max_drawdown = Decimal::ZERO;
}
}
/// Square root of a [`Decimal`] via Newton–Raphson iteration.
///
/// Returns `None` for negative input and `Some(0)` for zero.
///
/// The input is first normalised into `[1, 100)` by powers of 100, so that
/// `sqrt(n) = sqrt(m) * scale` with `scale` a power of ten. Convergence therefore takes
/// a handful of iterations regardless of magnitude. This matters for tiny inputs such
/// as per-tick return variances (around 1e-14 and below). There, an absolute stopping
/// tolerance or a fixed small iteration cap would return a root that is off by orders
/// of magnitude.
fn decimal_sqrt(n: Decimal) -> Option<Decimal> {
    if n < Decimal::ZERO {
        return None;
    }
    if n == Decimal::ZERO {
        return Some(Decimal::ZERO);
    }
    let ten = Decimal::from(10u64);
    let hundred = Decimal::from(100u64);

    // Normalise: n = m * scale^2 with m in [1, 100). Multiplying and dividing by
    // powers of ten is exact in decimal arithmetic.
    let mut m = n;
    let mut scale = Decimal::ONE;
    while m >= hundred {
        m /= hundred;
        scale *= ten;
    }
    while m < Decimal::ONE {
        m *= hundred;
        scale /= ten;
    }

    // Starting from x = m >= sqrt(m), Newton iterates decrease monotonically towards
    // the root. Stop as soon as rounding prevents further decrease; the cap only
    // guards against pathological rounding cycles.
    let mut x = m;
    for _ in 0..64 {
        let next = (x + m / x) / Decimal::TWO;
        if next >= x {
            break;
        }
        x = next;
    }
    Some(x * scale)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtest::VecDataSource;
    use crate::dec;

    /// Quotes symmetrically around the current tick's mid price and records every fill.
    ///
    /// With a positive spread, the bid sits below the mid and the ask above it. Assuming
    /// `mid_price()` lies between the tick's bid and ask, these quotes never cross the
    /// tick they are quoted on, so runs with this strategy produce no fills.
    struct TestStrategy {
        quote_spread: Decimal,
        fills_received: Vec<SimulatedFill>,
    }

    impl TestStrategy {
        fn new(spread: Decimal) -> Self {
            Self {
                quote_spread: spread,
                fills_received: Vec::new(),
            }
        }
    }

    impl BacktestStrategy for TestStrategy {
        fn on_tick(&mut self, tick: &MarketTick, _position: &InventoryPosition) -> Option<Quote> {
            let mid = tick.mid_price();
            let half_spread = self.quote_spread / Decimal::TWO;
            Some(Quote {
                bid_price: mid - half_spread,
                bid_size: Decimal::ONE,
                ask_price: mid + half_spread,
                ask_size: Decimal::ONE,
                timestamp: tick.timestamp,
            })
        }
        fn on_fill(&mut self, fill: &SimulatedFill) {
            self.fills_received.push(fill.clone());
        }
        fn reset(&mut self) {
            self.fills_received.clear();
        }
    }

    /// Never quotes, so it never trades.
    struct PassiveStrategy;

    impl BacktestStrategy for PassiveStrategy {
        fn on_tick(&mut self, _tick: &MarketTick, _position: &InventoryPosition) -> Option<Quote> {
            None
        }
        fn on_fill(&mut self, _fill: &SimulatedFill) {}
        fn reset(&mut self) {}
    }

    /// Bids at the tick's ask and offers at the tick's bid, so both sides cross
    /// (and fill) on every tick.
    struct CrossingStrategy {
        fills_received: u64,
    }

    impl CrossingStrategy {
        fn new() -> Self {
            Self { fills_received: 0 }
        }
    }

    impl BacktestStrategy for CrossingStrategy {
        fn on_tick(&mut self, tick: &MarketTick, _position: &InventoryPosition) -> Option<Quote> {
            Some(Quote {
                bid_price: tick.ask_price,
                bid_size: Decimal::ONE,
                ask_price: tick.bid_price,
                ask_size: Decimal::ONE,
                timestamp: tick.timestamp,
            })
        }
        fn on_fill(&mut self, _fill: &SimulatedFill) {
            self.fills_received += 1;
        }
        fn reset(&mut self) {
            self.fills_received = 0;
        }
    }

    /// Builds a tick with unit size on both sides.
    fn create_test_tick(timestamp: u64, bid: Decimal, ask: Decimal) -> MarketTick {
        MarketTick::new(timestamp, bid, dec!(1.0), ask, dec!(1.0))
    }

    #[test]
    fn test_simulated_fill_new() {
        let fill = SimulatedFill::new(Side::Buy, dec!(100.0), dec!(0.1), 1000);
        assert_eq!(fill.side, Side::Buy);
        assert_eq!(fill.price, dec!(100.0));
        assert_eq!(fill.quantity, dec!(0.1));
        assert_eq!(fill.timestamp, 1000);
        assert_eq!(fill.fee, Decimal::ZERO);
    }

    #[test]
    fn test_simulated_fill_with_fee() {
        let fill = SimulatedFill::with_fee(Side::Sell, dec!(100.0), dec!(0.1), 1000, dec!(0.01));
        assert_eq!(fill.fee, dec!(0.01));
    }

    #[test]
    fn test_simulated_fill_notional() {
        let fill = SimulatedFill::new(Side::Buy, dec!(100.0), dec!(0.5), 1000);
        assert_eq!(fill.notional(), dec!(50.0));
    }

    #[test]
    fn test_slippage_model_none() {
        let model = SlippageModel::None;
        assert_eq!(
            model.calculate_slippage(dec!(100.0), dec!(0.01)),
            Decimal::ZERO
        );
    }

    #[test]
    fn test_slippage_model_fixed() {
        let model = SlippageModel::Fixed(dec!(0.05));
        assert_eq!(
            model.calculate_slippage(dec!(100.0), dec!(0.01)),
            dec!(0.05)
        );
    }

    #[test]
    fn test_slippage_model_percentage() {
        let model = SlippageModel::Percentage(dec!(0.001));
        assert_eq!(model.calculate_slippage(dec!(100.0), dec!(0.01)), dec!(0.1));
    }

    #[test]
    fn test_slippage_model_volatility_based() {
        let model = SlippageModel::VolatilityBased {
            multiplier: dec!(2.0),
        };
        assert_eq!(model.calculate_slippage(dec!(100.0), dec!(0.01)), dec!(2.0));
    }

    #[test]
    fn test_backtest_config_default() {
        let config = BacktestConfig::default();
        assert_eq!(config.initial_capital, Decimal::from(100_000));
        assert_eq!(config.fee_rate, Decimal::ZERO);
        assert!(config.record_equity_curve);
        assert!(config.record_trades);
    }

    #[test]
    fn test_backtest_config_builder() {
        let config = BacktestConfig::new()
            .with_initial_capital(dec!(50000.0))
            .with_fee_rate(dec!(0.001))
            .with_slippage(SlippageModel::Fixed(dec!(0.01)));
        assert_eq!(config.initial_capital, dec!(50000.0));
        assert_eq!(config.fee_rate, dec!(0.001));
    }

    #[test]
    fn test_backtest_result_default() {
        let result = BacktestResult::default();
        assert_eq!(result.total_pnl, Decimal::ZERO);
        assert_eq!(result.num_trades, 0);
    }

    #[test]
    fn test_backtest_result_avg_trade_pnl() {
        let result = BacktestResult {
            net_pnl: dec!(100.0),
            num_trades: 10,
            ..Default::default()
        };
        assert_eq!(result.avg_trade_pnl(), dec!(10.0));
    }

    #[test]
    fn test_backtest_result_duration() {
        let result = BacktestResult {
            start_time: 1000,
            end_time: 5000,
            ..Default::default()
        };
        assert_eq!(result.duration_ms(), 4000);
    }

    #[test]
    fn test_backtest_engine_passive() {
        let ticks = vec![
            create_test_tick(1000, dec!(100.0), dec!(100.2)),
            create_test_tick(1001, dec!(100.1), dec!(100.3)),
            create_test_tick(1002, dec!(100.2), dec!(100.4)),
        ];
        let mut engine = BacktestEngine::new(
            BacktestConfig::default(),
            PassiveStrategy,
            VecDataSource::new(ticks),
        );
        let result = engine.run();
        assert_eq!(result.num_ticks, 3);
        assert_eq!(result.num_trades, 0);
        assert_eq!(result.start_time, 1000);
        assert_eq!(result.end_time, 1002);
    }

    #[test]
    fn test_backtest_engine_with_fills() {
        // Smoke test only: TestStrategy's mid-centred quotes do not cross the tick they
        // are quoted on (see its doc). Real fills are covered by the CrossingStrategy
        // tests below.
        let ticks = vec![
            create_test_tick(1000, dec!(100.0), dec!(100.2)),
            create_test_tick(1001, dec!(99.95), dec!(99.98)),
        ];
        let strategy = TestStrategy::new(dec!(0.02));
        let mut engine = BacktestEngine::new(
            BacktestConfig::default().with_default_order_size(dec!(1.0)),
            strategy,
            VecDataSource::new(ticks),
        );
        let result = engine.run();
        assert_eq!(result.num_ticks, 2);
        assert_eq!(result.start_time, 1000);
        assert_eq!(result.end_time, 1001);
    }

    #[test]
    fn test_backtest_engine_with_fees() {
        let ticks = vec![
            create_test_tick(1000, dec!(100.0), dec!(100.2)),
            create_test_tick(1001, dec!(99.8), dec!(100.0)),
        ];
        let strategy = TestStrategy::new(dec!(0.2));
        let config = BacktestConfig::default()
            .with_fee_rate(dec!(0.001))
            .with_default_order_size(dec!(1.0));
        let mut engine = BacktestEngine::new(config, strategy, VecDataSource::new(ticks));
        let result = engine.run();
        // Vacuous with TestStrategy (no fills); see
        // test_backtest_engine_crossing_quotes_fill_both_sides for exact fee totals.
        if result.num_trades > 0 {
            assert!(result.total_fees > Decimal::ZERO);
        }
    }

    #[test]
    fn test_backtest_engine_crossing_quotes_fill_both_sides() {
        let ticks = vec![
            create_test_tick(1000, dec!(100.0), dec!(100.2)),
            create_test_tick(1001, dec!(100.1), dec!(100.3)),
        ];
        let config = BacktestConfig::default()
            .with_fee_rate(dec!(0.001))
            .with_default_order_size(dec!(1.0));
        let mut engine =
            BacktestEngine::new(config, CrossingStrategy::new(), VecDataSource::new(ticks));
        let result = engine.run();
        // One buy and one sell per tick.
        assert_eq!(result.num_trades, 4);
        assert_eq!(result.trades.len(), 4);
        assert_eq!(engine.strategy().fills_received, 4);
        // The buy is processed first, at the quoted bid (= tick ask); then the sell at
        // the quoted ask (= tick bid).
        assert_eq!(result.trades[0].side, Side::Buy);
        assert_eq!(result.trades[0].price, dec!(100.2));
        assert_eq!(result.trades[1].side, Side::Sell);
        assert_eq!(result.trades[1].price, dec!(100.0));
        assert_eq!(result.trades[2].price, dec!(100.3));
        assert_eq!(result.trades[3].price, dec!(100.1));
        // fee = price * quantity * fee_rate: 0.1002 + 0.1 + 0.1003 + 0.1001
        assert_eq!(result.total_fees, dec!(0.4006));
    }

    #[test]
    fn test_backtest_engine_applies_fixed_slippage() {
        let ticks = vec![create_test_tick(1000, dec!(100.0), dec!(100.2))];
        let config = BacktestConfig::default().with_slippage(SlippageModel::Fixed(dec!(0.05)));
        let mut engine =
            BacktestEngine::new(config, CrossingStrategy::new(), VecDataSource::new(ticks));
        let result = engine.run();
        assert_eq!(result.trades.len(), 2);
        // Slippage moves buys up and sells down.
        assert_eq!(result.trades[0].price, dec!(100.25));
        assert_eq!(result.trades[1].price, dec!(99.95));
    }

    #[test]
    fn test_backtest_engine_no_trade_recording_still_charges_fees() {
        let ticks = vec![create_test_tick(1000, dec!(100.0), dec!(100.2))];
        let config = BacktestConfig::default()
            .with_fee_rate(dec!(0.001))
            .with_record_trades(false);
        let mut engine =
            BacktestEngine::new(config, CrossingStrategy::new(), VecDataSource::new(ticks));
        let result = engine.run();
        assert!(result.trades.is_empty());
        // Fills still happen and are reported to the strategy: 0.1002 + 0.1.
        assert_eq!(engine.strategy().fills_received, 2);
        assert_eq!(result.total_fees, dec!(0.2002));
    }

    #[test]
    fn test_backtest_engine_reset() {
        let ticks = vec![
            create_test_tick(1000, dec!(100.0), dec!(100.2)),
            create_test_tick(1001, dec!(100.1), dec!(100.3)),
        ];
        let mut engine = BacktestEngine::new(
            BacktestConfig::default(),
            PassiveStrategy,
            VecDataSource::new(ticks),
        );
        let result1 = engine.run();
        assert_eq!(result1.num_ticks, 2);
        engine.reset();
        let result2 = engine.run();
        assert_eq!(result2.num_ticks, 2);
    }

    #[test]
    fn test_backtest_engine_equity_curve() {
        let ticks = vec![
            create_test_tick(1000, dec!(100.0), dec!(100.2)),
            create_test_tick(1001, dec!(100.1), dec!(100.3)),
            create_test_tick(1002, dec!(100.2), dec!(100.4)),
        ];
        let config = BacktestConfig::default().with_record_equity_curve(true);
        let mut engine = BacktestEngine::new(config, PassiveStrategy, VecDataSource::new(ticks));
        let result = engine.run();
        assert_eq!(result.equity_curve.len(), 3);
    }

    #[test]
    fn test_backtest_engine_no_equity_curve() {
        let ticks = vec![
            create_test_tick(1000, dec!(100.0), dec!(100.2)),
            create_test_tick(1001, dec!(100.1), dec!(100.3)),
        ];
        let config = BacktestConfig::default().with_record_equity_curve(false);
        let mut engine = BacktestEngine::new(config, PassiveStrategy, VecDataSource::new(ticks));
        let result = engine.run();
        assert!(result.equity_curve.is_empty());
    }

    #[test]
    fn test_backtest_engine_get_state() {
        let ticks = vec![create_test_tick(1000, dec!(100.0), dec!(100.2))];
        let engine = BacktestEngine::new(
            BacktestConfig::default(),
            PassiveStrategy,
            VecDataSource::new(ticks),
        );
        let (position, pnl) = engine.get_state();
        assert_eq!(position.quantity, Decimal::ZERO);
        assert_eq!(pnl.total, Decimal::ZERO);
    }

    #[test]
    fn test_decimal_sqrt() {
        assert_eq!(decimal_sqrt(Decimal::ZERO), Some(Decimal::ZERO));
        assert!(decimal_sqrt(dec!(-1.0)).is_none());
        let sqrt_4 = decimal_sqrt(dec!(4.0)).unwrap();
        assert!((sqrt_4 - dec!(2.0)).abs() < dec!(0.0001));
        let sqrt_2 = decimal_sqrt(dec!(2.0)).unwrap();
        assert!((sqrt_2 - dec!(1.414)).abs() < dec!(0.001));
    }

    #[test]
    fn test_backtest_engine_progress_callback() {
        let ticks = vec![
            create_test_tick(1000, dec!(100.0), dec!(100.2)),
            create_test_tick(1001, dec!(100.1), dec!(100.3)),
            create_test_tick(1002, dec!(100.2), dec!(100.4)),
        ];
        let mut progress_calls = 0;
        let mut engine = BacktestEngine::new(
            BacktestConfig::default(),
            PassiveStrategy,
            VecDataSource::new(ticks),
        );
        engine.run_with_progress(|current, total| {
            progress_calls += 1;
            assert!(current <= total);
        });
        assert_eq!(progress_calls, 3);
    }

    #[test]
    fn test_backtest_result_avg_trade_pnl_zero_trades() {
        let result = BacktestResult {
            net_pnl: dec!(100.0),
            num_trades: 0,
            ..Default::default()
        };
        assert_eq!(result.avg_trade_pnl(), Decimal::ZERO);
    }

    #[test]
    fn test_backtest_config_with_slippage() {
        let config = BacktestConfig::default().with_slippage(SlippageModel::Fixed(dec!(0.01)));
        assert!(matches!(config.slippage, SlippageModel::Fixed(_)));
    }

    #[test]
    fn test_backtest_config_with_initial_capital() {
        let config = BacktestConfig::default().with_initial_capital(dec!(50000.0));
        assert_eq!(config.initial_capital, dec!(50000.0));
    }

    #[test]
    fn test_backtest_config_with_record_trades() {
        let config = BacktestConfig::default().with_record_trades(true);
        assert!(config.record_trades);
    }

    #[test]
    fn test_backtest_engine_with_actual_fills() {
        // Despite the name, TestStrategy's quotes never cross the current tick (see its
        // doc), so this is a smoke test; CrossingStrategy tests cover real fills.
        let ticks = vec![
            create_test_tick(1000, dec!(100.0), dec!(100.2)),
            create_test_tick(1001, dec!(99.0), dec!(99.1)),
            create_test_tick(1002, dec!(101.0), dec!(101.2)),
        ];
        let strategy = TestStrategy::new(dec!(2.0));
        let config = BacktestConfig::default()
            .with_default_order_size(dec!(1.0))
            .with_record_trades(true);
        let mut engine = BacktestEngine::new(config, strategy, VecDataSource::new(ticks));
        let result = engine.run();
        assert_eq!(result.num_ticks, 3);
    }

    #[test]
    fn test_backtest_engine_sharpe_calculation() {
        let mut ticks = Vec::new();
        for i in 0..20 {
            let price = dec!(100.0) + Decimal::from(i % 3);
            ticks.push(create_test_tick(1000 + i as u64, price, price + dec!(0.2)));
        }
        let config = BacktestConfig::default().with_record_equity_curve(true);
        let mut engine = BacktestEngine::new(config, PassiveStrategy, VecDataSource::new(ticks));
        let result = engine.run();
        assert_eq!(result.equity_curve.len(), 20);
        // A flat position gives a flat equity curve: zero-variance returns, so no Sharpe.
        assert!(result.sharpe_ratio.is_none());
    }

    #[test]
    fn test_backtest_engine_max_drawdown() {
        let ticks = vec![
            create_test_tick(1000, dec!(100.0), dec!(100.2)),
            create_test_tick(1001, dec!(95.0), dec!(95.2)),
            create_test_tick(1002, dec!(90.0), dec!(90.2)),
            create_test_tick(1003, dec!(100.0), dec!(100.2)),
        ];
        let config = BacktestConfig::default().with_record_equity_curve(true);
        let mut engine = BacktestEngine::new(config, PassiveStrategy, VecDataSource::new(ticks));
        let result = engine.run();
        assert_eq!(result.num_ticks, 4);
        // PassiveStrategy holds no position, so price moves cannot cause a drawdown.
        assert_eq!(result.max_drawdown, Decimal::ZERO);
    }

    #[test]
    fn test_backtest_result_is_profitable() {
        let profitable = BacktestResult {
            net_pnl: dec!(100.0),
            ..Default::default()
        };
        assert!(profitable.net_pnl > Decimal::ZERO);
        let unprofitable = BacktestResult {
            net_pnl: dec!(-50.0),
            ..Default::default()
        };
        assert!(unprofitable.net_pnl < Decimal::ZERO);
    }
}