mod cross_sectional;
mod live_bridge;
mod metrics;
mod monte_carlo;
mod sweep;
mod tearsheet;
mod walk_forward;
use chrono::{DateTime, Utc};
use polars::prelude::*;
pub use cross_sectional::{
assign_long_short_exposure, neutralize_factor, run_cross_sectional_backtest, winsorize_factor,
zscore_factor, CrossSectionalConfig,
};
pub use live_bridge::{
LiveBridge, LiveBridgeError, LiveSignalEvent, RecordingLiveBridge,
};
pub use metrics::{BacktestReport, PerformanceMetrics};
pub use tearsheet::{render_tearsheet_html, TearsheetOptions};
pub use monte_carlo::{
monte_carlo_trade_bootstrap, MonteCarloConfig, MonteCarloSummary,
monte_carlo_return_paths, MonteCarloReturnConfig, MonteCarloPathSummary,
};
pub use sweep::{run_param_sweep, single_param_variants, SweepVariant};
pub use walk_forward::{run_walk_forward, run_walk_forward_optimize, WalkForwardConfig};
#[allow(unused_imports)]
use quantwave_core::traits::Next; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
/// Errors produced by the backtest engine.
#[derive(Error, Debug)]
pub enum BacktestError {
    /// A Polars operation (collecting the lazy frame, reading or converting a column) failed.
    #[error("Polars error during simulation: {0}")]
    Polars(#[from] PolarsError),
    /// The input did not meet the engine's requirements, e.g. an empty frame, a missing
    /// column, an unsupported column dtype or mismatched column lengths.
    #[error("Invalid input: {0}")]
    InvalidInput(String),
    /// Rows were not in non-decreasing (timestamp, symbol) order; raised by the
    /// multi-symbol sortedness check.
    #[error("Data must be sorted by timestamp (and symbol for multi-symbol runs)")]
    UnsortedData,
}
/// Flat basis-point cost model carried by [`ExecutionModel::Simple`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostModel {
    /// Commission per fill, in basis points of fill notional (`|qty| * fill_price`).
    pub commission_bps: f64,
    /// Price adjustment per fill, in basis points (buys fill higher, sells lower).
    pub slippage_bps: f64,
    /// Starting cash for each simulated symbol.
    pub initial_cash: f64,
}
impl Default for CostModel {
    /// 5 bps commission, 2 bps slippage and 100,000 of starting cash.
    fn default() -> Self {
        CostModel {
            initial_cash: 100_000.0,
            slippage_bps: 2.0,
            commission_bps: 5.0,
        }
    }
}
/// Computes the commission charged for a single fill.
pub trait CommissionModel: Send + Sync + std::fmt::Debug {
    /// Returns the commission for a fill of `fill_quantity` units at `fill_price`.
    /// Implementations in this file use `fill_quantity.abs()`, so the sign of the
    /// quantity does not matter to them.
    fn calculate_commission(&self, fill_quantity: f64, fill_price: f64) -> f64;
}
/// Commission proportional to traded notional.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BpsCommissionModel {
    /// Commission rate in basis points (1 bp = 0.01%) of `|quantity| * price`.
    pub bps: f64,
}
impl CommissionModel for BpsCommissionModel {
    fn calculate_commission(&self, fill_quantity: f64, fill_price: f64) -> f64 {
        let notional = fill_quantity.abs() * fill_price;
        let rate = self.bps / 10_000.0;
        notional * rate
    }
}
/// Commission charged as a fixed amount per unit traded, independent of price.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct FixedPerShareCommissionModel {
    /// Amount charged per unit (share) of absolute fill quantity.
    pub per_share: f64,
}
impl CommissionModel for FixedPerShareCommissionModel {
    fn calculate_commission(&self, fill_quantity: f64, _fill_price: f64) -> f64 {
        let shares = fill_quantity.abs();
        shares * self.per_share
    }
}
/// Turns a reference price into a fill price.
pub trait SlippageModel: Send + Sync + std::fmt::Debug {
    /// Returns the fill price for an order of `quantity` units at reference `price`.
    ///
    /// `is_buy` selects the direction of the adjustment (implementations here raise buys and
    /// lower sells). `adv` is an optional liquidity figure used by impact models
    /// (presumably average daily volume in units — confirm with callers).
    fn apply(&self, price: f64, quantity: f64, is_buy: bool, adv: Option<f64>) -> f64;
}
/// Moves the fill price a fixed number of basis points against the trader.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BpsSlippageModel {
    /// Slippage in basis points (1 bp = 0.01%).
    pub bps: f64,
}
impl SlippageModel for BpsSlippageModel {
    /// Buys fill at `price * (1 + bps/10_000)`, sells at `price * (1 - bps/10_000)`.
    /// Quantity and ADV are ignored.
    fn apply(&self, price: f64, _quantity: f64, is_buy: bool, _adv: Option<f64>) -> f64 {
        let fraction = self.bps / 10_000.0;
        let multiplier = if is_buy { 1.0 + fraction } else { 1.0 - fraction };
        price * multiplier
    }
}
/// Square-root market-impact slippage: `impact = impact_coef * sqrt(participation)`.
///
/// `participation = |quantity| / adv`, capped at `max_participation`. With the derived
/// `Default` (`max_participation == 0.0`) the impact is always zero.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SquareRootMarketImpactSlippage {
    /// Scale applied to the square root of participation.
    pub impact_coef: f64,
    /// Upper bound on the participation ratio fed into the square root.
    pub max_participation: f64,
}
impl SquareRootMarketImpactSlippage {
    /// ADV used when the caller supplies none, or a value that is not finite and positive.
    const DEFAULT_ADV: f64 = 1_000_000.0;
}
impl SlippageModel for SquareRootMarketImpactSlippage {
    /// Buys fill at `price * (1 + impact)`, sells at `price * (1 - impact)`.
    ///
    /// NOTE(review): the simulation loop in this file always passes `adv = None`, so
    /// `DEFAULT_ADV` is what is used in practice.
    fn apply(&self, price: f64, quantity: f64, is_buy: bool, adv: Option<f64>) -> f64 {
        // A zero ADV would force maximum impact even for zero quantity. A negative ADV makes
        // the ratio negative and `sqrt` NaN. Treat both (and NaN/inf) as "no ADV supplied".
        let adv = adv
            .filter(|a| a.is_finite() && *a > 0.0)
            .unwrap_or(Self::DEFAULT_ADV);
        let part = (quantity.abs() / adv).min(self.max_participation);
        let impact = self.impact_coef * part.sqrt();
        if is_buy {
            price * (1.0 + impact)
        } else {
            price * (1.0 - impact)
        }
    }
}
/// Optional protective exits, each expressed as a fraction of price (e.g. `0.05` = 5%).
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct StopConfig {
    /// Exit when price moves this fraction against the entry price.
    pub stop_loss_pct: Option<f64>,
    /// Exit when price moves this fraction in favour of the entry price.
    pub take_profit_pct: Option<f64>,
    /// Exit when price retraces this fraction from its best level since entry.
    pub trailing_stop_pct: Option<f64>,
}
impl StopConfig {
    /// `true` when at least one of the three stops is configured.
    pub fn has_stops(&self) -> bool {
        [
            self.stop_loss_pct,
            self.take_profit_pct,
            self.trailing_stop_pct,
        ]
        .iter()
        .any(Option::is_some)
    }
}
/// Which bar's signal is acted on when processing a bar.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ExecutionDelay {
    /// The signal of bar `i` is filled at bar `i`'s close (the default).
    #[default]
    SameBar,
    /// The signal of bar `i - 1` is filled at bar `i`'s close; the first bar gets no signal.
    NextBar,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExecutionModel {
Simple(CostModel),
HighFidelity {
commission: BpsCommissionModel,
slippage: SquareRootMarketImpactSlippage,
},
}
impl Default for ExecutionModel {
fn default() -> Self {
ExecutionModel::Simple(CostModel::default())
}
}
impl ExecutionModel {
pub fn commission_for(&self, qty: f64, px: f64) -> f64 {
match self {
ExecutionModel::Simple(cm) => (qty.abs() * px) * (cm.commission_bps / 10_000.0),
ExecutionModel::HighFidelity { commission, .. } => commission.calculate_commission(qty, px),
}
}
pub fn slippage_price(&self, price: f64, qty: f64, is_buy: bool, adv: Option<f64>) -> f64 {
match self {
ExecutionModel::Simple(cm) => {
let s = cm.slippage_bps / 10_000.0;
if is_buy { price * (1.0 + s) } else { price * (1.0 - s) }
}
ExecutionModel::HighFidelity { slippage, .. } => slippage.apply(price, qty, is_buy, adv),
}
}
}
/// Sizes positions so that a fixed fraction of equity is put at risk per trade.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitialRiskPositionSizer {
    /// Fraction of equity to risk per trade (default `0.01`).
    pub initial_risk: f64,
    /// Upper bound on position notional as a fraction of equity (default `0.25`).
    pub max_target_pct: f64,
}
impl Default for InitialRiskPositionSizer {
    fn default() -> Self {
        Self { initial_risk: 0.01, max_target_pct: 0.25 }
    }
}
impl InitialRiskPositionSizer {
    /// Numerator used to derive a fraction-at-risk from `pole_height_atr` metadata
    /// (`fraction = POLE_HEIGHT_RISK_SCALE / pole_height_atr`).
    const POLE_HEIGHT_RISK_SCALE: f64 = 0.01;

    /// Converts a raw signal into a position size in units.
    ///
    /// The fraction-at-risk comes from `meta`:
    /// 1. a positive `fraction_at_risk` value, or else
    /// 2. `POLE_HEIGHT_RISK_SCALE / pole_height_atr` when `pole_height_atr` is positive.
    ///
    /// If a fraction is found, the result is
    /// `min(initial_risk / fraction, max_target_pct) * equity / price`, signed like
    /// `raw_exposure`. It is `0.0` when `price` is not finite and positive, or `equity` is not
    /// positive. Without usable metadata, `raw_exposure` is returned unchanged.
    pub fn compute_sized_exposure(
        &self,
        raw_exposure: f64,
        meta: &Option<HashMap<String, f64>>,
        price: f64,
        equity: f64,
    ) -> f64 {
        // Manual sign (not `signum`) so that 0.0 and NaN map to 0.0.
        let sign = if raw_exposure > 0.0 {
            1.0
        } else if raw_exposure < 0.0 {
            -1.0
        } else {
            0.0
        };
        let Some(m) = meta else {
            return raw_exposure;
        };
        let fraction_at_risk = m
            .get("fraction_at_risk")
            .copied()
            .filter(|frac| *frac > 0.0)
            .or_else(|| {
                m.get("pole_height_atr")
                    .copied()
                    .filter(|pole| *pole > 0.0)
                    .map(|pole| Self::POLE_HEIGHT_RISK_SCALE / pole)
            });
        match fraction_at_risk {
            Some(frac) => self.units_for_fraction_at_risk(frac, sign, price, equity),
            None => raw_exposure,
        }
    }

    /// Signed units for a given fraction-at-risk, or `0.0` when price/equity cannot size a
    /// position.
    fn units_for_fraction_at_risk(&self, frac: f64, sign: f64, price: f64, equity: f64) -> f64 {
        // A negative price or equity would flip the sign of the result and turn a long
        // signal into a short. A zero or NaN value would yield inf/NaN. Refuse to size instead.
        if !(price.is_finite() && price > 0.0 && equity > 0.0) {
            return 0.0;
        }
        let target_pct = (self.initial_risk / frac).min(self.max_target_pct);
        target_pct * equity / price * sign
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BacktestConfig {
pub cost_model: CostModel,
pub timestamp_col: String,
pub symbol_col: Option<String>,
pub close_col: String,
pub signal_col: String,
pub entry_filter_col: Option<String>,
pub size_multiplier_col: Option<String>,
pub execution_model: ExecutionModel,
pub execution_delay: ExecutionDelay,
pub stop_config: StopConfig,
pub position_sizer: Option<InitialRiskPositionSizer>,
}
impl Default for BacktestConfig {
fn default() -> Self {
Self {
cost_model: CostModel::default(),
timestamp_col: "timestamp".to_string(),
symbol_col: None,
close_col: "close".to_string(),
signal_col: "signal".to_string(),
entry_filter_col: None,
size_multiplier_col: None,
execution_model: ExecutionModel::default(),
execution_delay: ExecutionDelay::default(),
stop_config: StopConfig::default(),
position_sizer: None,
}
}
}
/// Index of the bar whose signal is executed at `bar`.
///
/// Returns `None` when the delay reaches before the first bar.
fn signal_bar_index(bar: usize, delay: ExecutionDelay) -> Option<usize> {
    let lag = match delay {
        ExecutionDelay::SameBar => 0,
        ExecutionDelay::NextBar => 1,
    };
    bar.checked_sub(lag)
}
/// One position from entry to exit, or to the last bar if it was still open.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trade {
    /// Sequential id within one simulation run (one per symbol in multi-symbol runs), starting at 1.
    pub trade_id: u32,
    /// Symbol, filled in only for multi-symbol runs.
    pub symbol: Option<String>,
    /// `1` for long, `-1` for short.
    pub side: i8,
    /// Timestamp of the entry bar.
    pub entry_ts: DateTime<Utc>,
    /// Entry fill price (after slippage).
    pub entry_price: f64,
    /// Same value as `entry_price`.
    pub entry_fill_price: f64,
    /// Timestamp of the exit bar; `None` if the position was still open at the last bar.
    pub exit_ts: Option<DateTime<Utc>>,
    /// Close of the exit bar (or of the last bar for a still-open position).
    pub exit_price: Option<f64>,
    /// Exit fill price after slippage; `None` for a still-open position.
    pub exit_fill_price: Option<f64>,
    /// PnL from the entry fill to the exit fill (or to the last close if still open).
    pub pnl_gross: f64,
    /// Commission on the exit fill only. The entry commission is debited from cash but not
    /// recorded here. `0.0` for a still-open position.
    pub costs: f64,
    /// `pnl_gross - costs`.
    pub pnl_net: f64,
    /// Absolute position size in units.
    pub quantity: f64,
    /// Signal metadata recorded with the trade.
    /// NOTE(review): exits triggered by a flat signal pass the exit bar's metadata rather
    /// than the entry's (see the flat-exit path in `run_simulation`). Confirm intended.
    pub entry_metadata: Option<HashMap<String, f64>>,
}
/// Account snapshot after processing one bar.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EquityPoint {
    /// Bar timestamp.
    pub ts: DateTime<Utc>,
    /// Symbol for per-symbol points; `None` for single-symbol runs and aggregated portfolio points.
    pub symbol: Option<String>,
    /// `cash + position * close` (summed across symbols for portfolio points).
    pub equity: f64,
    /// Cash balance after this bar's fills.
    pub cash: f64,
    /// Signed position in units (positive long, negative short).
    pub position: f64,
    /// Bar close; `0.0` on aggregated portfolio points.
    pub close: f64,
}
/// Output of [`BacktestEngine::run`].
#[derive(Debug)]
pub struct BacktestResult {
    /// One row per trade; timestamp columns are Unix seconds. Multi-symbol runs add `symbol`.
    pub trades: DataFrame,
    /// Equity snapshots; timestamps are Unix seconds. In multi-symbol runs, per-symbol rows are
    /// followed by aggregated portfolio rows, whose `symbol` is null.
    pub equity_curve: DataFrame,
    /// Summary values: `initial_cash`, `final_equity`, `total_return`, `num_trades`, `net_pnl`,
    /// plus `num_symbols` for multi-symbol runs.
    pub stats: HashMap<String, f64>,
}
impl BacktestResult {
    /// Computes [`PerformanceMetrics`] for this result.
    pub fn metrics(&self) -> PerformanceMetrics {
        PerformanceMetrics::from_result(self)
    }
}
/// A timestamp / close-price pair.
/// NOTE(review): not referenced in the visible part of this file; confirm it is still used.
#[derive(Debug, Clone)]
pub struct Bar {
    /// Bar timestamp.
    pub ts: DateTime<Utc>,
    /// Closing price.
    pub close: f64,
}
/// A desired exposure for one bar, plus optional numeric metadata.
///
/// `Default` is derived: zero exposure and no metadata.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StrategySignal {
    /// Target exposure; the engine treats positive as long, negative as short and zero as flat.
    pub exposure: f64,
    /// Optional named values attached to the signal (e.g. `pole_height`, `strength`).
    pub metadata: Option<HashMap<String, f64>>,
}
/// A detected pattern event that can be converted into a [`StrategySignal`].
/// NOTE(review): "PA" presumably means price action; confirm.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct PAEvent {
    /// Whether the event calls for a long position; `false` maps to zero exposure.
    pub long: bool,
    /// Optional pole height used to scale exposure via [`pole_height_to_exposure`].
    pub pole_height: Option<f64>,
    /// Optional strength score, carried through as metadata only.
    pub strength: Option<f64>,
}
impl PAEvent {
    /// Converts the event into a [`StrategySignal`].
    ///
    /// Follows the same rules as `parse_struct_signal_row`:
    /// - only finite values become metadata (`pole_height`, `strength`);
    /// - the pole height scales exposure only when it is finite and positive;
    /// - otherwise a long event gets exposure `1.0`;
    /// - non-long events are flat.
    pub fn to_strategy_signal(&self) -> StrategySignal {
        let mut meta = HashMap::new();
        for (key, value) in [("pole_height", self.pole_height), ("strength", self.strength)] {
            if let Some(v) = value.filter(|v| v.is_finite()) {
                meta.insert(key.to_string(), v);
            }
        }
        let exposure = if self.long {
            // A NaN or non-positive pole falls back to the unscaled default instead of
            // producing a NaN exposure or the clamped minimum.
            self.pole_height
                .filter(|p| p.is_finite() && *p > 0.0)
                .map_or(1.0, pole_height_to_exposure)
        } else {
            0.0
        };
        StrategySignal {
            exposure,
            metadata: if meta.is_empty() { None } else { Some(meta) },
        }
    }
}
/// Maps a pole height to an exposure: `pole_height / 4`, clamped to `[0.4, 2.2]`.
///
/// A NaN input yields NaN (`f64::clamp` propagates NaN).
pub fn pole_height_to_exposure(pole_height: f64) -> f64 {
    const DIVISOR: f64 = 4.0;
    const MIN_EXPOSURE: f64 = 0.4;
    const MAX_EXPOSURE: f64 = 2.2;
    let scaled = pole_height / DIVISOR;
    scaled.clamp(MIN_EXPOSURE, MAX_EXPOSURE)
}
/// Parses row `i` of a struct-typed signal column into `(exposure, metadata)`.
///
/// Exposure precedence:
/// 1. a finite, non-zero Float64 `exposure` field;
/// 2. otherwise `short == true` gives `-magnitude`;
/// 3. otherwise `long == true` gives `+magnitude`;
/// 4. otherwise `0.0`.
///
/// `magnitude` is [`pole_height_to_exposure`] of the first present metadata key among
/// `pole_height`, `pole_height_atr` and `pole_length_atr`, if that value is positive; else `1.0`.
///
/// Metadata holds every other field readable as Float64 whose value is finite. The current
/// implementation always returns `Ok`.
pub fn parse_struct_signal_row(
    ca: &StructChunked,
    i: usize,
) -> Result<(f64, Option<HashMap<String, f64>>), BacktestError> {
    let exposure_direct = struct_field_f64(ca, "exposure", i);
    let long = struct_field_bool(ca, "long", i);
    let short = struct_field_bool(ca, "short", i);

    let mut meta = HashMap::new();
    if let DataType::Struct(fields) = ca.dtype() {
        let extra_keys = fields
            .iter()
            .map(|field| field.name.as_str())
            .filter(|key| !matches!(*key, "exposure" | "long" | "short"));
        for key in extra_keys {
            if let Some(v) = struct_field_f64(ca, key, i).filter(|v| v.is_finite()) {
                meta.insert(key.to_string(), v);
            }
        }
    }

    let magnitude = ["pole_height", "pole_height_atr", "pole_length_atr"]
        .iter()
        .find_map(|name| meta.get(*name).copied())
        .filter(|v| *v > 0.0)
        .map_or(1.0, pole_height_to_exposure);

    // `short` is checked first, so it wins when both flags are set.
    let directional = if short == Some(true) {
        -magnitude
    } else if long == Some(true) {
        magnitude
    } else {
        0.0
    };

    let exposure = match exposure_direct {
        Some(e) if e.is_finite() && e != 0.0 => e,
        _ => directional,
    };
    let metadata = (!meta.is_empty()).then_some(meta);
    Ok((exposure, metadata))
}
/// Value of Float64 field `name` at row `i`.
///
/// Returns `None` if the field is missing, not Float64, null or out of range.
fn struct_field_f64(ca: &StructChunked, name: &str, i: usize) -> Option<f64> {
    match ca.field_by_name(name) {
        Ok(series) => series.f64().ok()?.get(i),
        Err(_) => None,
    }
}
/// Value of Boolean field `name` at row `i`.
///
/// Returns `None` if the field is missing, not Boolean, null or out of range.
fn struct_field_bool(ca: &StructChunked, name: &str, i: usize) -> Option<bool> {
    match ca.field_by_name(name) {
        Ok(series) => series.bool().ok()?.get(i),
        Err(_) => None,
    }
}
/// Runs bar-by-bar backtests over Polars frames according to a [`BacktestConfig`].
pub struct BacktestEngine {
    /// Column names and execution settings used by every run.
    config: BacktestConfig,
}
impl BacktestEngine {
pub fn new(config: BacktestConfig) -> Self {
Self { config }
}
pub fn with_default_costs() -> Self {
Self::new(BacktestConfig::default())
}
pub fn backtest_with_report(&self, lf: LazyFrame) -> Result<BacktestReport, BacktestError> {
let result = self.run(lf)?;
let metrics = PerformanceMetrics::from_result(&result);
Ok(BacktestReport { result, metrics })
}
pub fn run(&self, lf: LazyFrame) -> Result<BacktestResult, BacktestError> {
let df = lf.collect()?;
if df.height() == 0 {
return Err(BacktestError::InvalidInput("empty dataframe".into()));
}
let ts_col = &self.config.timestamp_col;
let close_col = &self.config.close_col;
let sig_col = &self.config.signal_col;
for c in [ts_col, close_col, sig_col] {
if df.column(c).is_err() {
return Err(BacktestError::InvalidInput(format!(
"missing column: {}",
c
)));
}
}
if self.config.symbol_col.is_some() {
return self.run_multi_symbol(df);
}
self.run_single_symbol(df)
}
pub fn run_metrics_only(&self, lf: LazyFrame) -> Result<PerformanceMetrics, BacktestError> {
let df = lf.collect()?;
if df.height() == 0 {
return Err(BacktestError::InvalidInput("empty dataframe".into()));
}
let ts_col = &self.config.timestamp_col;
let close_col = &self.config.close_col;
let sig_col = &self.config.signal_col;
for c in [ts_col, close_col, sig_col] {
if df.column(c).is_err() {
return Err(BacktestError::InvalidInput(format!(
"missing column: {}",
c
)));
}
}
if self.config.symbol_col.is_some() {
return self.run_metrics_multi_symbol(df);
}
self.run_metrics_single_symbol(df)
}
fn run_metrics_single_symbol(&self, df: DataFrame) -> Result<PerformanceMetrics, BacktestError> {
let (trades, equity_points) = self.simulate_dataframe(&df, None)?;
Ok(PerformanceMetrics::from_raw(&trades, &equity_points, self.per_symbol_initial_cash()))
}
fn run_metrics_multi_symbol(&self, df: DataFrame) -> Result<PerformanceMetrics, BacktestError> {
let sym_col = self
.config
.symbol_col
.as_ref()
.expect("symbol_col set");
if df.column(sym_col).is_err() {
return Err(BacktestError::InvalidInput(format!(
"missing column: {}",
sym_col
)));
}
let ts_series = df.column(&self.config.timestamp_col)?.clone();
let timestamps = self.extract_timestamps(&ts_series)?;
let symbols = extract_string_column(df.column(sym_col)?.clone())?;
validate_sorted_timestamp_symbol(×tamps, &symbols)?;
let mut unique_symbols: Vec<String> = Vec::new();
let mut seen = std::collections::HashSet::new();
for s in &symbols {
if seen.insert(s.clone()) {
unique_symbols.push(s.clone());
}
}
let mut all_trades: Vec<Trade> = Vec::new();
let mut per_symbol_equity: HashMap<String, Vec<EquityPoint>> = HashMap::new();
for symbol in &unique_symbols {
let sub = df
.clone()
.lazy()
.filter(col(sym_col).eq(lit(symbol.as_str())))
.sort(
[&self.config.timestamp_col],
SortMultipleOptions::default(),
)
.collect()?;
let (mut trades, equity_points) = self.simulate_dataframe(&sub, Some(symbol))?;
all_trades.append(&mut trades);
per_symbol_equity.insert(symbol.clone(), equity_points);
}
let portfolio_equity = aggregate_portfolio_equity(&per_symbol_equity);
let n_symbols = unique_symbols.len() as f64;
let portfolio_initial = self.per_symbol_initial_cash() * n_symbols;
Ok(PerformanceMetrics::from_raw(&all_trades, &portfolio_equity, portfolio_initial))
}
fn run_single_symbol(&self, df: DataFrame) -> Result<BacktestResult, BacktestError> {
let (trades, equity_points) = self.simulate_dataframe(&df, None)?;
let initial_cash = self.per_symbol_initial_cash();
let final_equity = equity_points
.last()
.map(|e| e.equity)
.unwrap_or(initial_cash);
let total_return = (final_equity - initial_cash) / initial_cash;
let num_trades = trades.len() as f64;
let mut stats = HashMap::new();
stats.insert("initial_cash".to_string(), initial_cash);
stats.insert("final_equity".to_string(), final_equity);
stats.insert("total_return".to_string(), total_return);
stats.insert("num_trades".to_string(), num_trades);
stats.insert("net_pnl".to_string(), final_equity - initial_cash);
Ok(BacktestResult {
trades: self.trades_to_df(&trades, false)?,
equity_curve: self.equity_to_df(&equity_points, false)?,
stats,
})
}
fn run_multi_symbol(&self, df: DataFrame) -> Result<BacktestResult, BacktestError> {
let sym_col = self
.config
.symbol_col
.as_ref()
.expect("symbol_col set");
if df.column(sym_col).is_err() {
return Err(BacktestError::InvalidInput(format!(
"missing column: {}",
sym_col
)));
}
let ts_series = df.column(&self.config.timestamp_col)?.clone();
let timestamps = self.extract_timestamps(&ts_series)?;
let symbols = extract_string_column(df.column(sym_col)?.clone())?;
validate_sorted_timestamp_symbol(×tamps, &symbols)?;
let mut unique_symbols: Vec<String> = Vec::new();
let mut seen = std::collections::HashSet::new();
for s in &symbols {
if seen.insert(s.clone()) {
unique_symbols.push(s.clone());
}
}
let per_symbol_initial = self.per_symbol_initial_cash();
let mut all_trades: Vec<Trade> = Vec::new();
let mut per_symbol_equity: HashMap<String, Vec<EquityPoint>> = HashMap::new();
for symbol in &unique_symbols {
let sub = df
.clone()
.lazy()
.filter(col(sym_col).eq(lit(symbol.as_str())))
.sort(
[&self.config.timestamp_col],
SortMultipleOptions::default(),
)
.collect()?;
let (mut trades, equity_points) = self.simulate_dataframe(&sub, Some(symbol))?;
all_trades.append(&mut trades);
per_symbol_equity.insert(symbol.clone(), equity_points);
}
let portfolio_equity = aggregate_portfolio_equity(&per_symbol_equity);
let mut combined_equity: Vec<EquityPoint> = per_symbol_equity
.values()
.flatten()
.cloned()
.collect();
combined_equity.extend(portfolio_equity.clone());
let n_symbols = unique_symbols.len() as f64;
let portfolio_initial = per_symbol_initial * n_symbols;
let portfolio_final = portfolio_equity
.last()
.map(|e| e.equity)
.unwrap_or(portfolio_initial);
let total_return = (portfolio_final - portfolio_initial) / portfolio_initial;
let num_trades = all_trades.len() as f64;
let mut stats = HashMap::new();
stats.insert("initial_cash".to_string(), portfolio_initial);
stats.insert("final_equity".to_string(), portfolio_final);
stats.insert("total_return".to_string(), total_return);
stats.insert("num_trades".to_string(), num_trades);
stats.insert("net_pnl".to_string(), portfolio_final - portfolio_initial);
stats.insert("num_symbols".to_string(), n_symbols);
Ok(BacktestResult {
trades: self.trades_to_df(&all_trades, true)?,
equity_curve: self.equity_to_df(&combined_equity, true)?,
stats,
})
}
fn per_symbol_initial_cash(&self) -> f64 {
match &self.config.execution_model {
ExecutionModel::Simple(cm) => cm.initial_cash,
_ => 100_000.0,
}
}
fn simulate_dataframe(
&self,
df: &DataFrame,
symbol: Option<&str>,
) -> Result<(Vec<Trade>, Vec<EquityPoint>), BacktestError> {
let ts_col = &self.config.timestamp_col;
let close_col = &self.config.close_col;
let sig_col = &self.config.signal_col;
let ts_series = df.column(ts_col)?.clone();
let close_ca = df.column(close_col)?.f64()?.clone();
let (signal_vals, signal_metas) = self.load_signals(df, sig_col)?;
let entry_filters = self.load_entry_filters(df)?;
let size_multipliers = self.load_size_multipliers(df)?;
let n = signal_vals.len();
if let Some(ref f) = entry_filters {
if f.len() != n {
return Err(BacktestError::InvalidInput(
"entry_filter column length mismatch".into(),
));
}
}
if let Some(ref m) = size_multipliers {
if m.len() != n {
return Err(BacktestError::InvalidInput(
"size_multiplier column length mismatch".into(),
));
}
}
let effective_signals: Vec<f64> = signal_vals
.iter()
.enumerate()
.map(|(i, &raw)| {
apply_signal_modifiers(
raw,
entry_filters.as_ref().map(|f| f[i]),
size_multipliers.as_ref().map(|m| m[i]),
)
})
.collect();
let timestamps = self.extract_timestamps(&ts_series)?;
let closes: Vec<f64> = close_ca
.into_iter()
.map(|v| v.unwrap_or(f64::NAN))
.collect();
if timestamps.len() != closes.len() || closes.len() != effective_signals.len() {
return Err(BacktestError::InvalidInput("column length mismatch".into()));
}
let exec = &self.config.execution_model;
let sizer = &self.config.position_sizer;
let mut effective_metas: Vec<Option<HashMap<String, f64>>> =
Vec::with_capacity(effective_signals.len());
for (i, &raw) in effective_signals.iter().enumerate() {
if raw == 0.0 {
effective_metas.push(None);
} else {
effective_metas.push(signal_metas.get(i).cloned().flatten());
}
}
let delay = self.config.execution_delay;
let stops = &self.config.stop_config;
let (mut trades, mut equity_points) = run_simulation(
×tamps,
&closes,
|i| (effective_signals[i], effective_metas[i].clone()),
exec,
sizer,
delay,
stops,
);
if let Some(sym) = symbol {
let sym_owned = sym.to_string();
for t in &mut trades {
t.symbol = Some(sym_owned.clone());
}
for e in &mut equity_points {
e.symbol = Some(sym_owned.clone());
}
}
Ok((trades, equity_points))
}
fn load_signals(
&self,
df: &DataFrame,
sig_col: &str,
) -> Result<(Vec<f64>, Vec<Option<HashMap<String, f64>>>), BacktestError> {
let signal_series = df.column(sig_col)?;
let s = signal_series
.as_series()
.ok_or_else(|| BacktestError::InvalidInput("column has no series backing".into()))?;
if s.dtype().is_struct() {
let ca = s.struct_().map_err(|e| BacktestError::Polars(e))?;
let n = ca.len();
let mut exposures = Vec::with_capacity(n);
let mut metas = Vec::with_capacity(n);
for i in 0..n {
let (exp, meta) = parse_struct_signal_row(ca, i)?;
exposures.push(exp);
metas.push(meta);
}
return Ok((exposures, metas));
}
let signal_vals: Vec<f64> = if signal_series.dtype().is_bool() {
signal_series
.bool()?
.into_iter()
.map(|b| if b.unwrap_or(false) { 1.0 } else { 0.0 })
.collect()
} else {
signal_series
.f64()?
.into_iter()
.map(|v| v.unwrap_or(0.0))
.collect()
};
let metas = vec![None; signal_vals.len()];
Ok((signal_vals, metas))
}
fn load_entry_filters(&self, df: &DataFrame) -> Result<Option<Vec<bool>>, BacktestError> {
let Some(col_name) = &self.config.entry_filter_col else {
return Ok(None);
};
if df.column(col_name).is_err() {
return Err(BacktestError::InvalidInput(format!(
"missing column: {}",
col_name
)));
}
extract_bool_column(df.column(col_name)?.clone())
.map(Some)
}
fn load_size_multipliers(&self, df: &DataFrame) -> Result<Option<Vec<f64>>, BacktestError> {
let Some(col_name) = &self.config.size_multiplier_col else {
return Ok(None);
};
if df.column(col_name).is_err() {
return Err(BacktestError::InvalidInput(format!(
"missing column: {}",
col_name
)));
}
extract_f64_column(df.column(col_name)?.clone())
.map(Some)
}
fn extract_timestamps(&self, col: &Column) -> Result<Vec<DateTime<Utc>>, BacktestError> {
let s = col
.as_series()
.ok_or_else(|| BacktestError::InvalidInput("column has no series backing".into()))?;
if let Ok(ca) = s.datetime() {
return Ok(ca
.into_iter()
.map(|opt| {
opt.map(|v| {
let secs = v / 1000;
let nanos = ((v % 1000) * 1_000_000) as u32;
DateTime::<Utc>::from_timestamp(secs, nanos).unwrap_or_else(Utc::now)
})
.unwrap_or_else(Utc::now)
})
.collect());
}
if let Ok(ca) = s.i64() {
return Ok(ca
.into_iter()
.enumerate()
.map(|(i, opt)| {
let v = opt.unwrap_or(i as i64);
DateTime::<Utc>::from_timestamp(v, 0).unwrap_or_else(Utc::now)
})
.collect());
}
Err(BacktestError::InvalidInput(
"timestamp column must be Datetime or Int64 for this MVP".into(),
))
}
fn trades_to_df(&self, trades: &[Trade], include_symbol: bool) -> Result<DataFrame, PolarsError> {
if trades.is_empty() {
let mut cols = vec![
Column::new("trade_id".into(), Vec::<u32>::new()),
Column::new("side".into(), Vec::<i8>::new()),
Column::new("entry_ts".into(), Vec::<i64>::new()),
Column::new("entry_price".into(), Vec::<f64>::new()),
Column::new("pnl_net".into(), Vec::<f64>::new()),
];
if include_symbol {
cols.push(Column::new("symbol".into(), Vec::<Option<String>>::new()));
}
return Ok(DataFrame::new(cols)?);
}
let ids: Vec<u32> = trades.iter().map(|t| t.trade_id).collect();
let sides: Vec<i8> = trades.iter().map(|t| t.side).collect();
let entry_ts: Vec<i64> = trades.iter().map(|t| t.entry_ts.timestamp()).collect();
let entry_px: Vec<f64> = trades.iter().map(|t| t.entry_price).collect();
let exit_ts: Vec<Option<i64>> = trades
.iter()
.map(|t| t.exit_ts.map(|d| d.timestamp()))
.collect();
let exit_px: Vec<Option<f64>> = trades.iter().map(|t| t.exit_price).collect();
let qty: Vec<f64> = trades.iter().map(|t| t.quantity).collect();
let pnl: Vec<f64> = trades.iter().map(|t| t.pnl_net).collect();
let mut cols = vec![
Column::new("trade_id".into(), ids),
Column::new("side".into(), sides),
Column::new("entry_ts".into(), entry_ts),
Column::new("entry_price".into(), entry_px),
Column::new("exit_ts".into(), exit_ts),
Column::new("exit_price".into(), exit_px),
Column::new("quantity".into(), qty),
Column::new("pnl_net".into(), pnl),
];
if include_symbol {
let symbols: Vec<Option<String>> = trades.iter().map(|t| t.symbol.clone()).collect();
cols.push(Column::new("symbol".into(), symbols));
}
DataFrame::new(cols)
}
fn equity_to_df(&self, points: &[EquityPoint], include_symbol: bool) -> Result<DataFrame, PolarsError> {
if points.is_empty() {
let mut cols = vec![
Column::new("ts".into(), Vec::<i64>::new()),
Column::new("equity".into(), Vec::<f64>::new()),
Column::new("position".into(), Vec::<f64>::new()),
];
if include_symbol {
cols.push(Column::new("symbol".into(), Vec::<Option<String>>::new()));
}
return Ok(DataFrame::new(cols)?);
}
let ts: Vec<i64> = points.iter().map(|p| p.ts.timestamp()).collect();
let eq: Vec<f64> = points.iter().map(|p| p.equity).collect();
let pos: Vec<f64> = points.iter().map(|p| p.position).collect();
let cash: Vec<f64> = points.iter().map(|p| p.cash).collect();
let close: Vec<f64> = points.iter().map(|p| p.close).collect();
let mut cols = vec![
Column::new("ts".into(), ts),
Column::new("equity".into(), eq),
Column::new("cash".into(), cash),
Column::new("position".into(), pos),
Column::new("close".into(), close),
];
if include_symbol {
let symbols: Vec<Option<String>> = points.iter().map(|p| p.symbol.clone()).collect();
cols.push(Column::new("symbol".into(), symbols));
}
DataFrame::new(cols)
}
}
/// Applies an optional entry filter and size multiplier to a raw signal.
///
/// An explicit `Some(false)` filter vetoes the bar; `None` means no filter is configured.
/// The multiplier, when present, scales the signal. Non-finite results, and both signed
/// zeros, collapse to a plain `0.0`.
pub fn apply_signal_modifiers(
    raw_signal: f64,
    entry_filter: Option<bool>,
    size_multiplier: Option<f64>,
) -> f64 {
    if entry_filter == Some(false) {
        return 0.0;
    }
    let scaled = size_multiplier.map_or(raw_signal, |m| raw_signal * m);
    Some(scaled)
        .filter(|v| v.is_finite() && *v != 0.0)
        .unwrap_or(0.0)
}
/// Reads a boolean column into a `Vec<bool>`, treating nulls as `false`.
///
/// # Errors
/// Returns [`BacktestError::InvalidInput`] when the column is not boolean.
fn extract_bool_column(col: Column) -> Result<Vec<bool>, BacktestError> {
    // `as_series` returns `None` for scalar (literal-broadcast) columns, which rejected
    // e.g. a constant filter added via `lit(true)`. Materializing accepts every column kind.
    let s = col.as_materialized_series();
    match s.bool() {
        Ok(ca) => Ok(ca.into_iter().map(|opt| opt.unwrap_or(false)).collect()),
        Err(_) => Err(BacktestError::InvalidInput(
            "entry_filter column must be boolean".into(),
        )),
    }
}
/// Reads a Float64 column into a `Vec<f64>`, treating nulls as `0.0`.
///
/// Used for size multipliers, so a null multiplier zeroes that bar's signal.
///
/// # Errors
/// Returns [`BacktestError::InvalidInput`] when the column is not Float64.
fn extract_f64_column(col: Column) -> Result<Vec<f64>, BacktestError> {
    // Materialize so scalar (literal-broadcast) columns are accepted; `as_series` returns
    // `None` for them.
    let s = col.as_materialized_series();
    match s.f64() {
        Ok(ca) => Ok(ca.into_iter().map(|opt| opt.unwrap_or(0.0)).collect()),
        Err(_) => Err(BacktestError::InvalidInput(
            "size_multiplier column must be f64".into(),
        )),
    }
}
/// Reads a String column into owned strings; nulls become `""`.
///
/// # Errors
/// Returns [`BacktestError::InvalidInput`] when the column is not a String column.
fn extract_string_column(col: Column) -> Result<Vec<String>, BacktestError> {
    // Materialize so scalar (literal-broadcast) columns, such as a constant symbol added
    // via `lit`, are accepted; `as_series` returns `None` for them.
    let s = col.as_materialized_series();
    match s.str() {
        Ok(ca) => Ok(ca
            .into_iter()
            .map(|opt| opt.unwrap_or_default().to_string())
            .collect()),
        Err(_) => Err(BacktestError::InvalidInput(
            "symbol column must be Utf8/String".into(),
        )),
    }
}
/// Checks that rows are in non-decreasing (timestamp, symbol) order.
///
/// # Errors
/// * [`BacktestError::InvalidInput`] if the slices differ in length.
/// * [`BacktestError::UnsortedData`] if any row sorts before its predecessor.
fn validate_sorted_timestamp_symbol(
    timestamps: &[DateTime<Utc>],
    symbols: &[String],
) -> Result<(), BacktestError> {
    if timestamps.len() != symbols.len() {
        return Err(BacktestError::InvalidInput("column length mismatch".into()));
    }
    let out_of_order = timestamps
        .windows(2)
        .zip(symbols.windows(2))
        .any(|(ts, sym)| (&ts[1], &sym[1]) < (&ts[0], &sym[0]));
    if out_of_order {
        Err(BacktestError::UnsortedData)
    } else {
        Ok(())
    }
}
fn aggregate_portfolio_equity(per_symbol: &HashMap<String, Vec<EquityPoint>>) -> Vec<EquityPoint> {
use std::collections::BTreeSet;
let mut ts_set = BTreeSet::new();
for points in per_symbol.values() {
for p in points {
ts_set.insert(p.ts);
}
}
ts_set
.into_iter()
.map(|ts| {
let mut total_equity = 0.0;
let mut total_cash = 0.0;
let mut total_position = 0.0;
for points in per_symbol.values() {
if let Some(p) = points.iter().find(|p| p.ts == ts) {
total_equity += p.equity;
total_cash += p.cash;
total_position += p.position;
}
}
EquityPoint {
ts,
symbol: None,
equity: total_equity,
cash: total_cash,
position: total_position,
close: 0.0,
}
})
.collect()
}
pub fn backtest_simple_bool_signal(
ohlcv: DataFrame,
signal_col: &str,
) -> Result<BacktestResult, BacktestError> {
let config = BacktestConfig {
signal_col: signal_col.to_string(),
..Default::default()
};
let engine = BacktestEngine::new(config);
engine.run(ohlcv.lazy())
}
fn run_simulation(
timestamps: &[DateTime<Utc>],
closes: &[f64],
mut next_signal: impl FnMut(usize) -> (f64, Option<HashMap<String, f64>>),
exec: &ExecutionModel,
sizer: &Option<InitialRiskPositionSizer>,
execution_delay: ExecutionDelay,
stop_config: &StopConfig,
) -> (Vec<Trade>, Vec<EquityPoint>) {
let mut cash = match exec {
ExecutionModel::Simple(cm) => cm.initial_cash,
ExecutionModel::HighFidelity { .. } => 100_000.0,
};
let mut current_exposure: f64 = 0.0;
let mut entry_price: f64 = 0.0;
let mut entry_ts: Option<DateTime<Utc>> = None;
let mut entry_metadata: Option<HashMap<String, f64>> = None;
let mut trailing_stop_level: Option<f64> = None;
let mut need_signal_reset = false;
let mut trade_id: u32 = 0;
let mut trades: Vec<Trade> = Vec::new();
let mut equity_points: Vec<EquityPoint> = Vec::with_capacity(closes.len());
let mut record_position_exit =
|cash: &mut f64,
tid: u32,
side: i8,
qty: f64,
entry_px: f64,
ets: DateTime<Utc>,
exit_bar: usize,
meta: Option<HashMap<String, f64>>| {
let close = closes[exit_bar];
let is_buy = side == -1;
let fill_price = exec.slippage_price(close, qty, is_buy, None);
let notional = fill_price * qty;
let cost = exec.commission_for(qty, fill_price);
let gross_pnl = if side == 1 {
(fill_price - entry_px) * qty
} else {
(entry_px - fill_price) * qty
};
let net_pnl = gross_pnl - cost;
if side == 1 {
*cash += notional - cost;
} else {
*cash -= notional + cost;
}
trades.push(Trade {
trade_id: tid,
symbol: None,
side,
entry_ts: ets,
entry_price: entry_px,
entry_fill_price: entry_px,
exit_ts: Some(timestamps[exit_bar]),
exit_price: Some(close),
exit_fill_price: Some(fill_price),
pnl_gross: gross_pnl,
costs: cost,
pnl_net: net_pnl,
quantity: qty,
entry_metadata: meta,
});
};
let open_position = |cash: &mut f64,
tid: u32,
desired: f64,
fill_bar: usize,
meta: Option<HashMap<String, f64>>|
-> (u32, f64, f64, Option<DateTime<Utc>>, Option<HashMap<String, f64>>, Option<f64>) {
let qty = desired.abs();
let is_long = desired > 0.0;
let is_buy = is_long;
let close = closes[fill_bar];
let fill_price = exec.slippage_price(close, qty, is_buy, None);
let notional = fill_price * qty;
let cost = exec.commission_for(qty, fill_price);
if is_long {
*cash -= notional + cost;
} else {
*cash += notional - cost;
}
let new_tid = tid + 1;
let exposure = if is_long { qty } else { -qty };
let trail = stop_config.trailing_stop_pct.map(|pct| {
if is_long {
fill_price * (1.0 - pct)
} else {
fill_price * (1.0 + pct)
}
});
(
new_tid,
exposure,
fill_price,
Some(timestamps[fill_bar]),
meta,
trail,
)
};
for i in 0..closes.len() {
let close = closes[i];
if !close.is_finite() {
let equity = cash + current_exposure * close;
equity_points.push(EquityPoint {
ts: timestamps[i],
symbol: None,
equity,
cash,
position: current_exposure,
close,
});
continue;
}
if current_exposure != 0.0 && stop_config.has_stops() {
let is_long = current_exposure > 0.0;
let qty = current_exposure.abs();
if let Some(trail_pct) = stop_config.trailing_stop_pct {
if is_long {
let new_level = close * (1.0 - trail_pct);
trailing_stop_level = Some(match trailing_stop_level {
Some(prev) => prev.max(new_level),
None => new_level,
});
} else {
let new_level = close * (1.0 + trail_pct);
trailing_stop_level = Some(match trailing_stop_level {
Some(prev) => prev.min(new_level),
None => new_level,
});
}
}
let mut stop_out = false;
if is_long {
if let Some(tp) = stop_config.take_profit_pct {
if close >= entry_price * (1.0 + tp) {
stop_out = true;
}
}
if !stop_out {
let mut effective_stop = f64::NEG_INFINITY;
if let Some(sl) = stop_config.stop_loss_pct {
effective_stop = entry_price * (1.0 - sl);
}
if let Some(level) = trailing_stop_level {
effective_stop = effective_stop.max(level);
}
if effective_stop > f64::NEG_INFINITY && close <= effective_stop {
stop_out = true;
}
}
} else {
if let Some(tp) = stop_config.take_profit_pct {
if close <= entry_price * (1.0 - tp) {
stop_out = true;
}
}
if !stop_out {
let mut effective_stop = f64::INFINITY;
if let Some(sl) = stop_config.stop_loss_pct {
effective_stop = entry_price * (1.0 + sl);
}
if let Some(level) = trailing_stop_level {
effective_stop = effective_stop.min(level);
}
if effective_stop < f64::INFINITY && close >= effective_stop {
stop_out = true;
}
}
}
if stop_out {
if let Some(ets) = entry_ts.take() {
let side = if is_long { 1 } else { -1 };
record_position_exit(
&mut cash,
trade_id,
side,
qty,
entry_price,
ets,
i,
entry_metadata.clone(),
);
current_exposure = 0.0;
entry_price = 0.0;
trailing_stop_level = None;
entry_metadata = None;
need_signal_reset = true;
}
}
}
let (raw_exposure, meta) = match signal_bar_index(i, execution_delay) {
Some(si) => next_signal(si),
None => (0.0, None),
};
let current_equity = cash + current_exposure * close;
let desired_exposure = if let Some(s) = sizer {
s.compute_sized_exposure(raw_exposure, &meta, close, current_equity)
} else {
raw_exposure
};
let desired = if desired_exposure.is_finite() && desired_exposure != 0.0 {
desired_exposure
} else {
0.0
};
if desired == 0.0 {
need_signal_reset = false;
}
let currently_in = current_exposure != 0.0;
if desired == 0.0 && currently_in {
if let Some(ets) = entry_ts.take() {
let side = if current_exposure > 0.0 { 1 } else { -1 };
record_position_exit(
&mut cash,
trade_id,
side,
current_exposure.abs(),
entry_price,
ets,
i,
meta.clone(),
);
current_exposure = 0.0;
entry_price = 0.0;
trailing_stop_level = None;
entry_metadata = None;
}
} else if desired != 0.0 && !need_signal_reset {
let want_long = desired > 0.0;
let in_long = current_exposure > 0.0;
let in_short = current_exposure < 0.0;
let flip = (want_long && in_short) || (!want_long && in_long);
if flip {
if let Some(ets) = entry_ts.take() {
let side = if in_long { 1 } else { -1 };
record_position_exit(
&mut cash,
trade_id,
side,
current_exposure.abs(),
entry_price,
ets,
i,
entry_metadata.clone(),
);
current_exposure = 0.0;
entry_price = 0.0;
trailing_stop_level = None;
entry_metadata = None;
}
}
if current_exposure == 0.0 {
let (new_tid, exp, ep, ets, em, trail) =
open_position(&mut cash, trade_id, desired, i, meta.clone());
trade_id = new_tid;
current_exposure = exp;
entry_price = ep;
entry_ts = ets;
entry_metadata = em;
trailing_stop_level = trail;
}
}
let equity = cash + current_exposure * close;
equity_points.push(EquityPoint {
ts: timestamps[i],
symbol: None,
equity,
cash,
position: current_exposure,
close,
});
}
if current_exposure != 0.0 {
let last_close = *closes.last().unwrap();
let qty = current_exposure.abs();
let side = if current_exposure > 0.0 { 1 } else { -1 };
let gross = if side == 1 {
(last_close - entry_price) * qty
} else {
(entry_price - last_close) * qty
};
if let Some(ets) = entry_ts {
trades.push(Trade {
trade_id,
symbol: None,
side,
entry_ts: ets,
entry_price,
entry_fill_price: entry_price,
exit_ts: None,
exit_price: Some(last_close),
exit_fill_price: None,
pnl_gross: gross,
costs: 0.0,
pnl_net: gross,
quantity: qty,
entry_metadata: None,
});
}
}
(trades, equity_points)
}
pub fn run_streaming_simulation<G>(
bars: &[Bar],
mut generator: G,
config: BacktestConfig,
) -> Result<BacktestResult, BacktestError>
where
G: for<'a> Next<&'a Bar, Output = StrategySignal>,
{
if bars.is_empty() {
return Err(BacktestError::InvalidInput("empty bars".into()));
}
let timestamps: Vec<DateTime<Utc>> = bars.iter().map(|b| b.ts).collect();
let closes: Vec<f64> = bars.iter().map(|b| b.close).collect();
let exec = &config.execution_model;
let sizer = &config.position_sizer;
let delay = config.execution_delay;
let stops = &config.stop_config;
let (trades, equity_points) = run_simulation(
×tamps,
&closes,
|i| {
let sig = generator.next(&bars[i]);
(sig.exposure, sig.metadata.clone())
},
exec,
sizer,
delay,
stops,
);
let trades_df = if trades.is_empty() {
DataFrame::new(vec![
Column::new("trade_id".into(), Vec::<u32>::new()),
Column::new("side".into(), Vec::<i8>::new()),
Column::new("entry_ts".into(), Vec::<i64>::new()),
Column::new("entry_price".into(), Vec::<f64>::new()),
Column::new("pnl_net".into(), Vec::<f64>::new()),
])?
} else {
let ids: Vec<u32> = trades.iter().map(|t| t.trade_id).collect();
let sides: Vec<i8> = trades.iter().map(|t| t.side).collect();
let entry_ts: Vec<i64> = trades.iter().map(|t| t.entry_ts.timestamp()).collect();
let entry_px: Vec<f64> = trades.iter().map(|t| t.entry_price).collect();
let exit_ts: Vec<Option<i64>> = trades
.iter()
.map(|t| t.exit_ts.map(|d| d.timestamp()))
.collect();
let exit_px: Vec<Option<f64>> = trades.iter().map(|t| t.exit_price).collect();
let pnl: Vec<f64> = trades.iter().map(|t| t.pnl_net).collect();
DataFrame::new(vec![
Column::new("trade_id".into(), ids),
Column::new("side".into(), sides),
Column::new("entry_ts".into(), entry_ts),
Column::new("entry_price".into(), entry_px),
Column::new("exit_ts".into(), exit_ts),
Column::new("exit_price".into(), exit_px),
Column::new("pnl_net".into(), pnl),
])?
};
let equity_df = if equity_points.is_empty() {
DataFrame::new(vec![
Column::new("ts".into(), Vec::<i64>::new()),
Column::new("equity".into(), Vec::<f64>::new()),
Column::new("position".into(), Vec::<f64>::new()),
])?
} else {
let ts: Vec<i64> = equity_points.iter().map(|p| p.ts.timestamp()).collect();
let eq: Vec<f64> = equity_points.iter().map(|p| p.equity).collect();
let pos: Vec<f64> = equity_points.iter().map(|p| p.position).collect();
let cash: Vec<f64> = equity_points.iter().map(|p| p.cash).collect();
let close: Vec<f64> = equity_points.iter().map(|p| p.close).collect();
DataFrame::new(vec![
Column::new("ts".into(), ts),
Column::new("equity".into(), eq),
Column::new("cash".into(), cash),
Column::new("position".into(), pos),
Column::new("close".into(), close),
])?
};
let initial_cash = match &config.execution_model {
ExecutionModel::Simple(cm) => cm.initial_cash,
_ => 100_000.0,
};
let final_equity = equity_points
.last()
.map(|e| e.equity)
.unwrap_or(initial_cash);
let total_return = (final_equity - initial_cash) / initial_cash;
let num_trades = trades.len() as f64;
let mut stats = HashMap::new();
stats.insert("initial_cash".to_string(), initial_cash);
stats.insert("final_equity".to_string(), final_equity);
stats.insert("total_return".to_string(), total_return);
stats.insert("num_trades".to_string(), num_trades);
stats.insert("net_pnl".to_string(), final_equity - initial_cash);
Ok(BacktestResult {
trades: trades_df,
equity_curve: equity_df,
stats,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use quantwave_core::features::CyberCycleFeatureExtractor;
    use quantwave_core::regimes::tar::TAR;
    use quantwave_core::regimes::MarketRegime;
    use quantwave_core::traits::Next;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use std::collections::{HashMap, VecDeque};

    /// Builds the `timestamp` / `close` / `signal` frame consumed by
    /// `backtest_simple_bool_signal`.
    fn signal_frame(timestamps: Vec<i64>, closes: Vec<f64>, signals: Vec<f64>) -> DataFrame {
        DataFrame::new(vec![
            Column::new("timestamp".into(), timestamps),
            Column::new("close".into(), closes),
            Column::new("signal".into(), signals),
        ])
        .expect("synthetic signal frame has equal-length columns")
    }

    /// Extracts the `equity` column of a result as plain `f64`s (nulls become 0.0).
    fn equity_series(result: &BacktestResult) -> Vec<f64> {
        result
            .equity_curve
            .column("equity")
            .unwrap()
            .f64()
            .unwrap()
            .into_iter()
            .map(|v| v.unwrap_or(0.0))
            .collect()
    }

    #[test]
    fn test_basic_long_only_flip_on_synthetic() {
        let n: usize = 6;
        let timestamps: Vec<i64> = (0..n)
            .map(|i| 1_700_000_000i64 + (i as i64) * 3600)
            .collect();
        let closes = vec![100.0, 101.0, 102.5, 103.0, 102.0, 101.0];
        let signals = vec![0.0, 1.0, 1.0, 1.0, 0.0, 0.0];
        let df = signal_frame(timestamps, closes, signals);
        let result = backtest_simple_bool_signal(df, "signal").expect("sim should succeed");

        assert_eq!(result.trades.height(), 1);
        let num_trades: f64 = *result.stats.get("num_trades").unwrap();
        assert_relative_eq!(num_trades, 1.0, epsilon = 1e-9);

        let final_eq = *result.stats.get("final_equity").unwrap();
        let init = 100_000.0;
        assert!(
            final_eq > init,
            "equity should grow on winning long: {} vs {}",
            final_eq,
            init
        );

        // The last point of the equity curve must agree with the reported final equity.
        assert_eq!(result.equity_curve.height(), n);
        let last_equity = result
            .equity_curve
            .column("equity")
            .unwrap()
            .f64()
            .unwrap()
            .get(n - 1)
            .unwrap();
        assert_relative_eq!(last_equity, final_eq, epsilon = 1e-6);
    }

    #[test]
    fn test_flat_always_signal_produces_no_trades_and_flat_equity() {
        let n: usize = 5;
        let ts: Vec<i64> = (0..n).map(|i| 1_700_000_100 + i as i64).collect();
        let df = signal_frame(ts, vec![100.0; n], vec![0.0; n]);
        let result = backtest_simple_bool_signal(df, "signal").unwrap();

        assert_eq!(result.trades.height(), 0);
        let num = *result.stats.get("num_trades").unwrap();
        assert_relative_eq!(num, 0.0, epsilon = 1e-9);
        let final_equity_val = *result.stats.get("final_equity").unwrap();
        assert_relative_eq!(final_equity_val, 100_000.0, epsilon = 1e-4);
    }

    #[test]
    fn test_synthetic_with_small_random_walk_and_bool_signal_matches_manual_calc() {
        // A fixed seed keeps the price path, and therefore the test, reproducible.
        let mut rng = StdRng::seed_from_u64(0x5EED);
        let n: usize = 8;
        let signals = vec![0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0];
        let mut price = 100.0_f64;
        let mut closes = Vec::with_capacity(n);
        let mut ts = Vec::with_capacity(n);
        for i in 0..n {
            ts.push(1_700_000_200 + i as i64);
            closes.push(price);
            price += rng.gen_range(-0.8..1.2);
        }
        let df = signal_frame(ts, closes.clone(), signals.clone());
        let result = backtest_simple_bool_signal(df, "signal").unwrap();

        // Reference model: trade one unit per entry. Slippage moves the fill
        // price against us, and commission is charged once per fill. The rates
        // match `CostModel::default()` (2 bps slippage, 5 bps commission).
        // NOTE(review): assumes `backtest_simple_bool_signal` uses those defaults.
        let slip = 0.0002;
        let comm = 0.0005;
        let init = 100_000.0;
        let mut cash = init;
        let mut pos = 0.0;
        let mut manual_equity = init;
        for (&c, &sig) in closes.iter().zip(&signals) {
            let want_long = sig > 0.0;
            if want_long && pos == 0.0 {
                let fill = c * (1.0 + slip);
                cash -= fill * (1.0 + comm);
                pos = 1.0;
            } else if !want_long && pos > 0.0 {
                // Proceeds net of commission; commission is charged only once.
                let fill = c * (1.0 - slip);
                cash += fill * (1.0 - comm);
                pos = 0.0;
            }
            manual_equity = cash + pos * c;
        }

        let engine_final = *result.stats.get("final_equity").unwrap();
        assert_relative_eq!(engine_final, manual_equity, epsilon = 0.5);
    }

    /// Toy price-action detector: the pole height is the high/low range of the
    /// last `max_len` prices, floored at 0.1. It reports 1.0 until three prices
    /// have been seen.
    #[derive(Debug, Clone)]
    struct SyntheticPoleHeightDetector {
        window: VecDeque<f64>,
        max_len: usize,
    }

    impl SyntheticPoleHeightDetector {
        fn new(max_len: usize) -> Self {
            Self {
                window: VecDeque::with_capacity(max_len),
                max_len,
            }
        }
    }

    #[derive(Debug, Clone, Copy)]
    struct PoleOutput {
        pole_height: f64,
        _strength: f64,
    }

    impl Next<f64> for SyntheticPoleHeightDetector {
        type Output = PoleOutput;

        fn next(&mut self, price: f64) -> PoleOutput {
            self.window.push_back(price);
            if self.window.len() > self.max_len {
                self.window.pop_front();
            }
            let h = if self.window.len() >= 3 {
                let mn = self.window.iter().fold(f64::INFINITY, |a, &b| a.min(b));
                let mx = self.window.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                (mx - mn).max(0.1)
            } else {
                1.0
            };
            PoleOutput {
                pole_height: h,
                _strength: (h / 8.0).clamp(0.3, 1.0),
            }
        }
    }

    /// Strategy combining a regime filter, a cycle-momentum feature and
    /// pole-height sizing. It is used to check batch/streaming parity on a
    /// realistic stateful signal.
    #[derive(Debug, Clone)]
    struct RegimeFeaturePAStrategy {
        regime: TAR,
        cycle: CyberCycleFeatureExtractor,
        pa: SyntheticPoleHeightDetector,
        feat_thresh: f64,
    }

    impl RegimeFeaturePAStrategy {
        fn new() -> Self {
            Self {
                regime: TAR::new(105.0),
                cycle: CyberCycleFeatureExtractor::new(14),
                pa: SyntheticPoleHeightDetector::new(6),
                feat_thresh: 0.02,
            }
        }
    }

    impl Next<&Bar> for RegimeFeaturePAStrategy {
        type Output = StrategySignal;

        fn next(&mut self, bar: &Bar) -> StrategySignal {
            let regime = self.regime.next(bar.close);
            let feat = self.cycle.next(bar.close);
            let pa = self.pa.next(bar.close);
            let regime_ok = matches!(
                regime,
                MarketRegime::Steady | MarketRegime::Cluster(_) | MarketRegime::Bull
            );
            let feat_ok = feat.cycle_momentum.abs() > self.feat_thresh;
            let exposure = if regime_ok && feat_ok {
                (pa.pole_height / 4.0).clamp(0.4, 2.2)
            } else {
                0.0
            };
            let mut meta = HashMap::new();
            meta.insert("pole_height".to_string(), pa.pole_height);
            meta.insert("cycle_momentum".to_string(), feat.cycle_momentum);
            meta.insert("regime_ok".to_string(), if regime_ok { 1.0 } else { 0.0 });
            StrategySignal {
                exposure,
                metadata: Some(meta),
            }
        }
    }

    #[test]
    fn test_batch_vs_streaming_parity_regime_feature_rich_pa_pole_sizing() {
        let n: usize = 120;
        // Gently trending sine wave, one bar per hour.
        let bars: Vec<Bar> = (0..n)
            .map(|i| {
                let secs = 1_700_000_500i64 + (i as i64) * 3600;
                let ts = chrono::DateTime::<chrono::Utc>::from_timestamp(secs, 0).unwrap();
                let wave = (i as f64 * 0.18).sin() * 4.5;
                let close = 101.5 + wave + (i as f64 * 0.008);
                Bar { ts, close }
            })
            .collect();

        // Batch path: precompute the exposure series, then backtest it as a column.
        let mut batch_gen = RegimeFeaturePAStrategy::new();
        let exposures: Vec<f64> = bars.iter().map(|bar| batch_gen.next(bar).exposure).collect();
        let df = signal_frame(
            bars.iter().map(|b| b.ts.timestamp()).collect(),
            bars.iter().map(|b| b.close).collect(),
            exposures,
        );
        let batch_res = backtest_simple_bool_signal(df, "signal").expect("batch parity run");

        // Streaming path: a fresh, identically-configured generator fed bar by bar.
        let stream_res = run_streaming_simulation(
            &bars,
            RegimeFeaturePAStrategy::new(),
            BacktestConfig::default(),
        )
        .expect("streaming parity run");

        let b_eq = equity_series(&batch_res);
        let s_eq = equity_series(&stream_res);
        assert_eq!(b_eq.len(), s_eq.len(), "equity curve lengths must match");
        for (i, (b, s)) in b_eq.iter().zip(&s_eq).enumerate() {
            assert!(
                (b - s).abs() <= 1e-7,
                "equity diverged at bar {}: {} vs {}",
                i,
                b,
                s
            );
        }

        for k in ["final_equity", "net_pnl", "num_trades"] {
            let bv = *batch_res.stats.get(k).unwrap();
            let sv = *stream_res.stats.get(k).unwrap();
            assert_relative_eq!(bv, sv, epsilon = 1e-6, max_relative = 1e-6);
        }

        assert_eq!(
            batch_res.trades.height(),
            stream_res.trades.height(),
            "trade counts must match exactly for parity"
        );
        assert!(
            batch_res.trades.height() >= 1,
            "parity test strategy must generate >=1 trade on synthetic data"
        );
    }
}
#[cfg(test)]
mod integration_example_between_epics {
    use super::*;
    use quantwave_core::features::HurstFeatureExtractor;

    /// Smoke test: exposures derived from an ML feature (Hurst persistence)
    /// drive `BacktestEngine` through a custom signal column.
    #[test]
    fn ml_features_feed_backtester_with_metadata() {
        let n = 60;
        let closes: Vec<f64> = (0..n).map(|i| 100.0 + i as f64 * 0.25).collect();
        let timestamps: Vec<i64> = (0..n).map(|i| 1_700_000_000i64 + i as i64).collect();

        // Hold one unit while the series looks persistent; stay flat otherwise.
        let mut h_ext = HurstFeatureExtractor::new(15);
        let exposures: Vec<f64> = closes
            .iter()
            .map(|&c| if h_ext.next(c).persistence > 0.52 { 1.0 } else { 0.0 })
            .collect();

        let lf = df![
            "timestamp" => timestamps,
            "close" => closes,
            "exposure" => exposures,
        ]
        .unwrap()
        .lazy();
        let config = BacktestConfig {
            signal_col: "exposure".to_string(),
            ..Default::default()
        };
        let result = BacktestEngine::new(config).run(lf).unwrap();
        println!(
            "Integration smoke test: {} trades produced using ML feature (Hurst) driven exposure",
            result.trades.height()
        );
        // The engine emits one equity point per input row.
        assert_eq!(result.equity_curve.height(), n);
    }

    /// Checks the sizer's behaviour for three inputs:
    /// * a `pole_height_atr` hint;
    /// * a `fraction_at_risk` hint;
    /// * no metadata, where the raw exposure passes through.
    #[test]
    fn test_initial_risk_position_sizer_with_pole_height_and_fraction() {
        let sizer = InitialRiskPositionSizer {
            initial_risk: 0.01,
            max_target_pct: 0.5,
        };

        let mut meta = HashMap::new();
        meta.insert("pole_height_atr".to_string(), 2.0);
        let sig = StrategySignal {
            exposure: 1.0,
            metadata: Some(meta),
        };
        let sized = sizer.compute_sized_exposure(1.0, &sig.metadata, 100.0, 1_000_000.0);
        assert!((sized - 5000.0).abs() < 1.0, "pole-height sizing gave {}", sized);

        let mut meta2 = HashMap::new();
        meta2.insert("fraction_at_risk".to_string(), 0.02);
        let sig2 = StrategySignal {
            exposure: 1.0,
            metadata: Some(meta2),
        };
        let sized2 = sizer.compute_sized_exposure(1.0, &sig2.metadata, 100.0, 1_000_000.0);
        assert!((sized2 - 5000.0).abs() < 1.0, "fraction sizing gave {}", sized2);

        let sig3 = StrategySignal {
            exposure: 123.0,
            metadata: None,
        };
        let sized3 = sizer.compute_sized_exposure(123.0, &sig3.metadata, 100.0, 1_000_000.0);
        assert!((sized3 - 123.0).abs() < 1e-9, "passthrough sizing gave {}", sized3);
    }
}