use rustrade_core::Symbol;
use rustrade_risk::{CircuitBreakerConfig, SessionPnlConfig, SizingConfig};
use crate::error::{Error, Result};
use crate::fees::FeeModel;
use crate::slippage::SlippageModel;
/// Settings for a single backtest run.
///
/// Normally produced by `BacktestConfig::builder()` followed by
/// `BacktestConfigBuilder::build`, which enforces the invariants noted on each field.
/// Every field is `pub`, so a struct literal can bypass that validation; code that
/// constructs this type directly must uphold the invariants itself.
#[derive(Debug, Clone)]
pub struct BacktestConfig {
/// Instruments in the backtest. The builder rejects an empty list; when there is
/// exactly one entry, `BacktestConfig::symbol()` returns it.
pub symbols: Vec<Symbol>,
/// Starting cash balance. Builder default `10_000.0`; the builder requires it to be
/// finite and strictly positive.
pub initial_cash: f64,
/// Position-sizing rules (builder default: `SizingConfig::default()`).
pub sizing: SizingConfig,
/// Slippage model (builder default: `SlippageModel::default()`, which the unit tests
/// expect to be `SlippageModel::Zero`).
pub slippage: SlippageModel,
/// Fee model (builder default: `FeeModel::default()`).
pub fees: FeeModel,
/// Contract value. Builder default `1.0`; the builder requires it to be finite and
/// strictly positive. NOTE(review): presumably a per-contract multiplier — confirm
/// against the engine.
pub contract_value: f64,
/// Risk-free rate. Builder default `0.0`; the builder only requires it to be finite
/// (negative values are accepted). NOTE(review): units/period not visible here.
pub risk_free_rate: f64,
/// Periods per year. Builder default `252`; the builder rejects `0`.
/// NOTE(review): presumably used to annualize statistics — confirm.
pub periods_per_year: u32,
/// Optional session P&L settings; when present the builder rejects a NaN `loss_limit`.
pub session_pnl: Option<SessionPnlConfig>,
/// Optional circuit-breaker settings; passed through by the builder unvalidated.
pub circuit_breaker: Option<CircuitBreakerConfig>,
}
impl BacktestConfig {
    /// Return the sole symbol of a single-symbol backtest.
    ///
    /// Convenience accessor for the common one-symbol case. Multi-symbol callers should
    /// read the `symbols` field directly.
    ///
    /// # Panics
    ///
    /// Panics unless `self.symbols` holds exactly one symbol. Configs produced by
    /// `BacktestConfigBuilder::build` always hold at least one symbol, but a config built
    /// through a struct literal (all fields are `pub`) may hold zero.
    // `#[track_caller]` makes the panic report the caller's location, i.e. the code that
    // misused the accessor, instead of this line.
    #[track_caller]
    pub fn symbol(&self) -> &Symbol {
        assert_eq!(
            self.symbols.len(),
            1,
            "BacktestConfig::symbol() is only valid for single-symbol backtests; \
             this config has {} symbols. Use BacktestConfig::symbols instead.",
            self.symbols.len()
        );
        &self.symbols[0]
    }
}
impl BacktestConfig {
pub fn builder() -> BacktestConfigBuilder {
BacktestConfigBuilder::default()
}
}
/// Incremental builder for [`BacktestConfig`]; obtain one via `BacktestConfig::builder()`.
///
/// Each setter consumes and returns the builder so calls can be chained. Setters only
/// store values; defaults and validation are applied by `BacktestConfigBuilder::build`.
// A builder that is dropped without calling `build` has no effect, so discarding one
// (e.g. a setter call whose result is ignored) is almost certainly a bug.
#[must_use = "a BacktestConfigBuilder does nothing until `.build()` is called"]
#[derive(Debug, Clone, Default)]
pub struct BacktestConfigBuilder {
    /// Symbols to trade; `build` rejects an empty list.
    symbols: Vec<Symbol>,
    /// `None` means `build` uses its default (`10_000.0`).
    initial_cash: Option<f64>,
    /// `None` means `build` uses `SizingConfig::default()`.
    sizing: Option<SizingConfig>,
    /// `None` means `build` uses `SlippageModel::default()`.
    slippage: Option<SlippageModel>,
    /// `None` means `build` uses `FeeModel::default()`.
    fees: Option<FeeModel>,
    /// `None` means `build` uses its default (`1.0`).
    contract_value: Option<f64>,
    /// `None` means `build` uses its default (`0.0`).
    risk_free_rate: Option<f64>,
    /// `None` means `build` uses its default (`252`).
    periods_per_year: Option<u32>,
    /// Copied into the built config as-is (after a NaN `loss_limit` check).
    session_pnl: Option<SessionPnlConfig>,
    /// Copied into the built config as-is.
    circuit_breaker: Option<CircuitBreakerConfig>,
}
impl BacktestConfigBuilder {
pub fn symbol(mut self, sym: impl Into<Symbol>) -> Self {
self.symbols = vec![sym.into()];
self
}
pub fn symbols<I, S>(mut self, syms: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<Symbol>,
{
self.symbols = syms.into_iter().map(Into::into).collect();
self
}
pub fn initial_cash(mut self, cash: f64) -> Self {
self.initial_cash = Some(cash);
self
}
pub fn sizing(mut self, sizing: SizingConfig) -> Self {
self.sizing = Some(sizing);
self
}
pub fn slippage(mut self, m: SlippageModel) -> Self {
self.slippage = Some(m);
self
}
pub fn fees(mut self, m: FeeModel) -> Self {
self.fees = Some(m);
self
}
pub fn contract_value(mut self, cv: f64) -> Self {
self.contract_value = Some(cv);
self
}
pub fn risk_free_rate(mut self, r: f64) -> Self {
self.risk_free_rate = Some(r);
self
}
pub fn periods_per_year(mut self, n: u32) -> Self {
self.periods_per_year = Some(n);
self
}
pub fn session_pnl(mut self, cfg: SessionPnlConfig) -> Self {
self.session_pnl = Some(cfg);
self
}
pub fn circuit_breaker(mut self, cfg: CircuitBreakerConfig) -> Self {
self.circuit_breaker = Some(cfg);
self
}
pub fn build(self) -> Result<BacktestConfig> {
if self.symbols.is_empty() {
return Err(Error::Config(
"BacktestConfig requires at least one symbol".into(),
));
}
let initial_cash = self.initial_cash.unwrap_or(10_000.0);
if !initial_cash.is_finite() || initial_cash <= 0.0 {
return Err(Error::Config(
"BacktestConfig.initial_cash must be a finite positive number".into(),
));
}
let contract_value = self.contract_value.unwrap_or(1.0);
if !contract_value.is_finite() || contract_value <= 0.0 {
return Err(Error::Config(
"BacktestConfig.contract_value must be a finite positive number".into(),
));
}
let risk_free_rate = self.risk_free_rate.unwrap_or(0.0);
if !risk_free_rate.is_finite() {
return Err(Error::Config(
"BacktestConfig.risk_free_rate must be finite".into(),
));
}
let periods_per_year = self.periods_per_year.unwrap_or(252);
if periods_per_year == 0 {
return Err(Error::Config(
"BacktestConfig.periods_per_year must be > 0".into(),
));
}
if let Some(sp) = &self.session_pnl
&& sp.loss_limit.is_nan()
{
return Err(Error::Config(
"BacktestConfig.session_pnl.loss_limit must not be NaN".into(),
));
}
Ok(BacktestConfig {
symbols: self.symbols,
initial_cash,
sizing: self.sizing.unwrap_or_default(),
slippage: self.slippage.unwrap_or_default(),
fees: self.fees.unwrap_or_default(),
contract_value,
risk_free_rate,
periods_per_year,
session_pnl: self.session_pnl,
circuit_breaker: self.circuit_breaker,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `r` failed specifically with a configuration error.
    fn is_config_err(r: &Result<BacktestConfig>) -> bool {
        matches!(r, Err(Error::Config(_)))
    }

    #[test]
    fn requires_symbol() {
        assert!(is_config_err(&BacktestConfig::builder().build()));
    }

    #[test]
    fn empty_symbols_iter_clears_previous_symbol() {
        // `symbols` replaces rather than appends, so an empty iterator leaves no symbols.
        let r = BacktestConfig::builder()
            .symbol("X")
            .symbols(Vec::<&str>::new())
            .build();
        assert!(is_config_err(&r));
    }

    #[test]
    fn rejects_non_positive_cash() {
        for cash in [-100.0, 0.0, -0.0] {
            let r = BacktestConfig::builder()
                .symbol("BTCUSDT")
                .initial_cash(cash)
                .build();
            assert!(is_config_err(&r), "initial_cash = {cash} should be rejected");
        }
    }

    #[test]
    fn rejects_non_finite_cash() {
        for cash in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let r = BacktestConfig::builder()
                .symbol("BTCUSDT")
                .initial_cash(cash)
                .build();
            assert!(is_config_err(&r), "initial_cash = {cash} should be rejected");
        }
    }

    #[test]
    fn rejects_non_positive_contract_value() {
        for cv in [0.0, -1.0] {
            let r = BacktestConfig::builder().symbol("X").contract_value(cv).build();
            assert!(is_config_err(&r), "contract_value = {cv} should be rejected");
        }
    }

    #[test]
    fn rejects_non_finite_contract_value() {
        for cv in [f64::NAN, f64::INFINITY] {
            let r = BacktestConfig::builder().symbol("X").contract_value(cv).build();
            assert!(is_config_err(&r), "contract_value = {cv} should be rejected");
        }
    }

    #[test]
    fn rejects_zero_periods_per_year() {
        let r = BacktestConfig::builder()
            .symbol("X")
            .periods_per_year(0)
            .build();
        assert!(is_config_err(&r));
    }

    #[test]
    fn rejects_nan_risk_free_rate() {
        let r = BacktestConfig::builder()
            .symbol("X")
            .risk_free_rate(f64::NAN)
            .build();
        assert!(is_config_err(&r));
    }

    #[test]
    fn rejects_infinite_risk_free_rate() {
        for rate in [f64::INFINITY, f64::NEG_INFINITY] {
            let r = BacktestConfig::builder().symbol("X").risk_free_rate(rate).build();
            assert!(is_config_err(&r), "risk_free_rate = {rate} should be rejected");
        }
    }

    #[test]
    fn accepts_negative_risk_free_rate() {
        let c = BacktestConfig::builder()
            .symbol("X")
            .risk_free_rate(-0.01)
            .build()
            .unwrap();
        assert_eq!(c.risk_free_rate, -0.01);
    }

    #[test]
    fn defaults_for_optional_fields() {
        let c = BacktestConfig::builder().symbol("X").build().unwrap();
        assert_eq!(c.symbols.len(), 1);
        assert_eq!(c.initial_cash, 10_000.0);
        assert_eq!(c.contract_value, 1.0);
        assert_eq!(c.slippage, SlippageModel::Zero);
        assert_eq!(c.risk_free_rate, 0.0);
        assert_eq!(c.periods_per_year, 252);
        assert!(c.session_pnl.is_none());
        assert!(c.circuit_breaker.is_none());
    }

    #[test]
    fn multi_symbol_config_round_trips() {
        let c = BacktestConfig::builder()
            .symbols(["BTCUSDT", "ETHUSDT", "SOLUSDT"])
            .build()
            .unwrap();
        assert_eq!(c.symbols.len(), 3);
        assert_eq!(c.symbols[0].as_str(), "BTCUSDT");
        assert_eq!(c.symbols[1].as_str(), "ETHUSDT");
        assert_eq!(c.symbols[2].as_str(), "SOLUSDT");
    }

    #[test]
    fn symbol_setter_replaces_previous_symbols() {
        let c = BacktestConfig::builder()
            .symbols(["A", "B"])
            .symbol("C")
            .build()
            .unwrap();
        assert_eq!(c.symbols.len(), 1);
        assert_eq!(c.symbol().as_str(), "C");
    }

    #[test]
    #[should_panic(expected = "only valid for single-symbol backtests")]
    fn symbol_accessor_panics_on_multi_symbol() {
        let c = BacktestConfig::builder()
            .symbols(["A", "B"])
            .build()
            .unwrap();
        let _ = c.symbol();
    }

    #[test]
    fn symbol_accessor_works_on_single_symbol() {
        let c = BacktestConfig::builder().symbol("X").build().unwrap();
        assert_eq!(c.symbol().as_str(), "X");
    }
}