1mod cross_sectional;
77mod live_bridge;
78mod metrics;
79mod monte_carlo;
80mod sweep;
81mod tearsheet;
82mod walk_forward;
83
84use chrono::{DateTime, Utc};
85use polars::prelude::*;
86pub use cross_sectional::{
87 assign_long_short_exposure, neutralize_factor, run_cross_sectional_backtest, winsorize_factor,
88 zscore_factor, CrossSectionalConfig,
89};
90pub use live_bridge::{
91 LiveBridge, LiveBridgeError, LiveSignalEvent, RecordingLiveBridge,
92};
93pub use metrics::{BacktestReport, PerformanceMetrics};
94pub use tearsheet::{render_tearsheet_html, TearsheetOptions};
95pub use monte_carlo::{
96 monte_carlo_trade_bootstrap, MonteCarloConfig, MonteCarloSummary,
97 monte_carlo_return_paths, MonteCarloReturnConfig, MonteCarloPathSummary,
98};
99pub use sweep::{run_param_sweep, single_param_variants, SweepVariant};
100pub use walk_forward::{run_walk_forward, run_walk_forward_optimize, WalkForwardConfig};
101#[allow(unused_imports)]
102use quantwave_core::traits::Next; use serde::{Deserialize, Serialize};
104use std::collections::HashMap;
105use thiserror::Error;
106
107#[derive(Error, Debug)]
109pub enum BacktestError {
110 #[error("Polars error during simulation: {0}")]
111 Polars(#[from] PolarsError),
112
113 #[error("Invalid input: {0}")]
114 InvalidInput(String),
115
116 #[error("Data must be sorted by timestamp (and symbol for multi-symbol runs)")]
117 UnsortedData,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct CostModel {
123 pub commission_bps: f64,
125 pub slippage_bps: f64,
127 pub initial_cash: f64,
129}
130
131impl Default for CostModel {
132 fn default() -> Self {
133 Self {
134 commission_bps: 5.0, slippage_bps: 2.0, initial_cash: 100_000.0,
137 }
138 }
139}
140
/// Pluggable commission schedule.
///
/// Implementors return the commission charged for a single fill.
pub trait CommissionModel: Send + Sync + std::fmt::Debug {
    /// Commission for filling `fill_quantity` units (sign = direction) at `fill_price`.
    fn calculate_commission(&self, fill_quantity: f64, fill_price: f64) -> f64;
}
145
146#[derive(Debug, Clone, Serialize, Deserialize, Default)]
147pub struct BpsCommissionModel {
148 pub bps: f64,
150}
151
152impl CommissionModel for BpsCommissionModel {
153 fn calculate_commission(&self, fill_quantity: f64, fill_price: f64) -> f64 {
154 (fill_quantity.abs() * fill_price) * (self.bps / 10_000.0)
155 }
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize, Default)]
159pub struct FixedPerShareCommissionModel {
160 pub per_share: f64,
161}
162
163impl CommissionModel for FixedPerShareCommissionModel {
164 fn calculate_commission(&self, fill_quantity: f64, _fill_price: f64) -> f64 {
165 fill_quantity.abs() * self.per_share
166 }
167}
168
/// Pluggable fill-price model.
///
/// Given a reference `price`, returns the price actually obtained. `is_buy` selects the
/// adverse direction; `adv` is an optional average daily volume for impact models.
pub trait SlippageModel: Send + Sync + std::fmt::Debug {
    /// Slippage-adjusted fill price for `quantity` units traded at reference `price`.
    fn apply(&self, price: f64, quantity: f64, is_buy: bool, adv: Option<f64>) -> f64;
}
173
174#[derive(Debug, Clone, Serialize, Deserialize, Default)]
175pub struct BpsSlippageModel {
176 pub bps: f64,
177}
178
179impl SlippageModel for BpsSlippageModel {
180 fn apply(&self, price: f64, _quantity: f64, is_buy: bool, _adv: Option<f64>) -> f64 {
181 let s = self.bps / 10_000.0;
182 if is_buy { price * (1.0 + s) } else { price * (1.0 - s) }
183 }
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize, Default)]
187pub struct SquareRootMarketImpactSlippage {
188 pub impact_coef: f64,
189 pub max_participation: f64,
190}
191
192impl SlippageModel for SquareRootMarketImpactSlippage {
193 fn apply(&self, price: f64, quantity: f64, is_buy: bool, adv: Option<f64>) -> f64 {
194 let adv = adv.unwrap_or(1_000_000.0);
195 let part = (quantity.abs() / adv).min(self.max_participation);
196 let impact = self.impact_coef * part.sqrt();
197 if is_buy { price * (1.0 + impact) } else { price * (1.0 - impact) }
198 }
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
206pub struct StopConfig {
207 pub stop_loss_pct: Option<f64>,
209 pub take_profit_pct: Option<f64>,
211 pub trailing_stop_pct: Option<f64>,
213}
214
215impl StopConfig {
216 pub fn has_stops(&self) -> bool {
217 self.stop_loss_pct.is_some()
218 || self.take_profit_pct.is_some()
219 || self.trailing_stop_pct.is_some()
220 }
221}
222
223#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
225pub enum ExecutionDelay {
226 #[default]
228 SameBar,
229 NextBar,
231}
232
233#[derive(Debug, Clone, Serialize, Deserialize)]
235pub enum ExecutionModel {
236 Simple(CostModel),
237 HighFidelity {
238 commission: BpsCommissionModel,
239 slippage: SquareRootMarketImpactSlippage,
240 },
241}
242
243impl Default for ExecutionModel {
244 fn default() -> Self {
245 ExecutionModel::Simple(CostModel::default())
246 }
247}
248
249impl ExecutionModel {
250 pub fn commission_for(&self, qty: f64, px: f64) -> f64 {
251 match self {
252 ExecutionModel::Simple(cm) => (qty.abs() * px) * (cm.commission_bps / 10_000.0),
253 ExecutionModel::HighFidelity { commission, .. } => commission.calculate_commission(qty, px),
254 }
255 }
256 pub fn slippage_price(&self, price: f64, qty: f64, is_buy: bool, adv: Option<f64>) -> f64 {
257 match self {
258 ExecutionModel::Simple(cm) => {
259 let s = cm.slippage_bps / 10_000.0;
260 if is_buy { price * (1.0 + s) } else { price * (1.0 - s) }
261 }
262 ExecutionModel::HighFidelity { slippage, .. } => slippage.apply(price, qty, is_buy, adv),
263 }
264 }
265}
266
267#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct InitialRiskPositionSizer {
273 pub initial_risk: f64,
275 pub max_target_pct: f64,
277}
278
279impl Default for InitialRiskPositionSizer {
280 fn default() -> Self {
281 Self { initial_risk: 0.01, max_target_pct: 0.25 }
282 }
283}
284
285impl InitialRiskPositionSizer {
286 pub fn compute_sized_exposure(
290 &self,
291 raw_exposure: f64,
292 meta: &Option<HashMap<String, f64>>,
293 price: f64,
294 equity: f64,
295 ) -> f64 {
296 let sign = if raw_exposure > 0.0 { 1.0 } else if raw_exposure < 0.0 { -1.0 } else { 0.0 };
297 if let Some(m) = meta {
298 if let Some(frac) = m.get("fraction_at_risk").copied() {
300 if frac > 0.0 {
301 let target_pct = (self.initial_risk / frac).min(self.max_target_pct);
302 let target_units = target_pct * equity / price * sign;
303 return target_units;
304 }
305 }
306 if let Some(pole) = m.get("pole_height_atr").copied() {
308 if pole > 0.0 {
309 let frac = 0.01 / pole;
311 let target_pct = (self.initial_risk / frac).min(self.max_target_pct);
312 let target_units = target_pct * equity / price * sign;
313 return target_units;
314 }
315 }
316 }
317 raw_exposure
318 }
319}
320
321#[derive(Debug, Clone, Serialize, Deserialize)]
323pub struct BacktestConfig {
324 pub cost_model: CostModel,
325 pub timestamp_col: String,
327 pub symbol_col: Option<String>,
328 pub close_col: String,
329 pub signal_col: String,
336 pub entry_filter_col: Option<String>,
340 pub size_multiplier_col: Option<String>,
343
344 pub execution_model: ExecutionModel,
346 pub execution_delay: ExecutionDelay,
348 pub stop_config: StopConfig,
350 pub position_sizer: Option<InitialRiskPositionSizer>,
353}
354
355impl Default for BacktestConfig {
356 fn default() -> Self {
357 Self {
358 cost_model: CostModel::default(),
359 timestamp_col: "timestamp".to_string(),
360 symbol_col: None,
361 close_col: "close".to_string(),
362 signal_col: "signal".to_string(),
363 entry_filter_col: None,
364 size_multiplier_col: None,
365 execution_model: ExecutionModel::default(),
366 execution_delay: ExecutionDelay::default(),
367 stop_config: StopConfig::default(),
368 position_sizer: None,
369 }
370 }
371}
372
373fn signal_bar_index(bar: usize, delay: ExecutionDelay) -> Option<usize> {
375 match delay {
376 ExecutionDelay::SameBar => Some(bar),
377 ExecutionDelay::NextBar => bar.checked_sub(1),
378 }
379}
380
381#[derive(Debug, Clone, Serialize, Deserialize)]
383pub struct Trade {
384 pub trade_id: u32,
385 pub symbol: Option<String>,
386 pub side: i8, pub entry_ts: DateTime<Utc>,
388 pub entry_price: f64,
389 pub entry_fill_price: f64, pub exit_ts: Option<DateTime<Utc>>,
391 pub exit_price: Option<f64>,
392 pub exit_fill_price: Option<f64>,
393 pub pnl_gross: f64,
394 pub costs: f64,
395 pub pnl_net: f64,
396 pub quantity: f64,
399 pub entry_metadata: Option<HashMap<String, f64>>,
402}
403
404#[derive(Debug, Clone, Serialize, Deserialize)]
406pub struct EquityPoint {
407 pub ts: DateTime<Utc>,
408 pub symbol: Option<String>, pub equity: f64,
410 pub cash: f64,
411 pub position: f64, pub close: f64,
413}
414
415#[derive(Debug)]
417pub struct BacktestResult {
418 pub trades: DataFrame,
420 pub equity_curve: DataFrame,
422 pub stats: HashMap<String, f64>,
424}
425
426impl BacktestResult {
427 pub fn metrics(&self) -> PerformanceMetrics {
429 PerformanceMetrics::from_result(self)
430 }
431}
432
433#[derive(Debug, Clone)]
436pub struct Bar {
437 pub ts: DateTime<Utc>,
438 pub close: f64,
439}
440
441#[derive(Debug, Clone, Serialize, Deserialize)]
445pub struct StrategySignal {
446 pub exposure: f64,
448 pub metadata: Option<HashMap<String, f64>>,
451}
452
453impl Default for StrategySignal {
454 fn default() -> Self {
455 Self {
456 exposure: 0.0,
457 metadata: None,
458 }
459 }
460}
461
462#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
466pub struct PAEvent {
467 pub long: bool,
469 pub pole_height: Option<f64>,
471 pub strength: Option<f64>,
473}
474
475impl PAEvent {
476 pub fn to_strategy_signal(&self) -> StrategySignal {
478 let mut meta = HashMap::new();
479 if let Some(p) = self.pole_height {
480 meta.insert("pole_height".to_string(), p);
481 }
482 if let Some(s) = self.strength {
483 meta.insert("strength".to_string(), s);
484 }
485 let exposure = if self.long {
486 self.pole_height
487 .map(pole_height_to_exposure)
488 .unwrap_or(1.0)
489 } else {
490 0.0
491 };
492 StrategySignal {
493 exposure,
494 metadata: if meta.is_empty() { None } else { Some(meta) },
495 }
496 }
497}
498
/// Maps a pattern's pole height to a long exposure: one quarter of the height, clamped to
/// `[0.4, 2.2]`. A NaN height yields NaN (as `f64::clamp` propagates NaN).
pub fn pole_height_to_exposure(pole_height: f64) -> f64 {
    const MIN_EXPOSURE: f64 = 0.4;
    const MAX_EXPOSURE: f64 = 2.2;
    let quarter = pole_height / 4.0;
    quarter.clamp(MIN_EXPOSURE, MAX_EXPOSURE)
}
503
504pub fn parse_struct_signal_row(
512 ca: &StructChunked,
513 i: usize,
514) -> Result<(f64, Option<HashMap<String, f64>>), BacktestError> {
515 let mut meta = HashMap::new();
516
517 let exposure_direct = struct_field_f64(ca, "exposure", i);
518 let long = struct_field_bool(ca, "long", i);
519 let short = struct_field_bool(ca, "short", i);
520
521 if let DataType::Struct(fields) = ca.dtype() {
522 for field in fields {
523 let key = field.name.as_str();
524 if matches!(key, "exposure" | "long" | "short") {
525 continue;
526 }
527 if let Some(v) = struct_field_f64(ca, key, i) {
528 if v.is_finite() {
529 meta.insert(key.to_string(), v);
530 }
531 }
532 }
533 }
534
535 let pole = ["pole_height", "pole_height_atr", "pole_length_atr"]
536 .iter()
537 .find_map(|name| meta.get(*name).copied())
538 .filter(|v| *v > 0.0);
539
540 let exposure = if let Some(e) = exposure_direct {
541 if e.is_finite() && e != 0.0 {
542 e
543 } else if short.unwrap_or(false) {
544 let mag = pole.map(pole_height_to_exposure).unwrap_or(1.0);
545 -mag
546 } else if long.unwrap_or(false) {
547 pole.map(pole_height_to_exposure).unwrap_or(1.0)
548 } else {
549 0.0
550 }
551 } else if short.unwrap_or(false) {
552 let mag = pole.map(pole_height_to_exposure).unwrap_or(1.0);
553 -mag
554 } else if long.unwrap_or(false) {
555 pole.map(pole_height_to_exposure).unwrap_or(1.0)
556 } else {
557 0.0
558 };
559
560 let metadata = if meta.is_empty() { None } else { Some(meta) };
561 Ok((exposure, metadata))
562}
563
564fn struct_field_f64(ca: &StructChunked, name: &str, i: usize) -> Option<f64> {
565 let field = ca.field_by_name(name).ok()?;
566 field.f64().ok().and_then(|arr| arr.get(i))
567}
568
569fn struct_field_bool(ca: &StructChunked, name: &str, i: usize) -> Option<bool> {
570 let field = ca.field_by_name(name).ok()?;
571 field.bool().ok().and_then(|arr| arr.get(i))
572}
573
574pub struct BacktestEngine {
585 config: BacktestConfig,
586}
587
588impl BacktestEngine {
589 pub fn new(config: BacktestConfig) -> Self {
590 Self { config }
591 }
592
593 pub fn with_default_costs() -> Self {
594 Self::new(BacktestConfig::default())
595 }
596
597 pub fn backtest_with_report(&self, lf: LazyFrame) -> Result<BacktestReport, BacktestError> {
599 let result = self.run(lf)?;
600 let metrics = PerformanceMetrics::from_result(&result);
601 Ok(BacktestReport { result, metrics })
602 }
603
604 pub fn run(&self, lf: LazyFrame) -> Result<BacktestResult, BacktestError> {
608 let df = lf.collect()?;
609
610 if df.height() == 0 {
611 return Err(BacktestError::InvalidInput("empty dataframe".into()));
612 }
613
614 let ts_col = &self.config.timestamp_col;
615 let close_col = &self.config.close_col;
616 let sig_col = &self.config.signal_col;
617
618 for c in [ts_col, close_col, sig_col] {
619 if df.column(c).is_err() {
620 return Err(BacktestError::InvalidInput(format!(
621 "missing column: {}",
622 c
623 )));
624 }
625 }
626
627 if self.config.symbol_col.is_some() {
628 return self.run_multi_symbol(df);
629 }
630
631 self.run_single_symbol(df)
632 }
633
634 pub fn run_metrics_only(&self, lf: LazyFrame) -> Result<PerformanceMetrics, BacktestError> {
635 let df = lf.collect()?;
636
637 if df.height() == 0 {
638 return Err(BacktestError::InvalidInput("empty dataframe".into()));
639 }
640
641 let ts_col = &self.config.timestamp_col;
642 let close_col = &self.config.close_col;
643 let sig_col = &self.config.signal_col;
644
645 for c in [ts_col, close_col, sig_col] {
646 if df.column(c).is_err() {
647 return Err(BacktestError::InvalidInput(format!(
648 "missing column: {}",
649 c
650 )));
651 }
652 }
653
654 if self.config.symbol_col.is_some() {
655 return self.run_metrics_multi_symbol(df);
656 }
657
658 self.run_metrics_single_symbol(df)
659 }
660
661 fn run_metrics_single_symbol(&self, df: DataFrame) -> Result<PerformanceMetrics, BacktestError> {
662 let (trades, equity_points) = self.simulate_dataframe(&df, None)?;
663 Ok(PerformanceMetrics::from_raw(&trades, &equity_points, self.per_symbol_initial_cash()))
664 }
665
666 fn run_metrics_multi_symbol(&self, df: DataFrame) -> Result<PerformanceMetrics, BacktestError> {
667 let sym_col = self
668 .config
669 .symbol_col
670 .as_ref()
671 .expect("symbol_col set");
672
673 if df.column(sym_col).is_err() {
674 return Err(BacktestError::InvalidInput(format!(
675 "missing column: {}",
676 sym_col
677 )));
678 }
679
680 let ts_series = df.column(&self.config.timestamp_col)?.clone();
681 let timestamps = self.extract_timestamps(&ts_series)?;
682 let symbols = extract_string_column(df.column(sym_col)?.clone())?;
683 validate_sorted_timestamp_symbol(×tamps, &symbols)?;
684
685 let mut unique_symbols: Vec<String> = Vec::new();
686 let mut seen = std::collections::HashSet::new();
687 for s in &symbols {
688 if seen.insert(s.clone()) {
689 unique_symbols.push(s.clone());
690 }
691 }
692
693 let mut all_trades: Vec<Trade> = Vec::new();
694 let mut per_symbol_equity: HashMap<String, Vec<EquityPoint>> = HashMap::new();
695
696 for symbol in &unique_symbols {
697 let sub = df
698 .clone()
699 .lazy()
700 .filter(col(sym_col).eq(lit(symbol.as_str())))
701 .sort(
702 [&self.config.timestamp_col],
703 SortMultipleOptions::default(),
704 )
705 .collect()?;
706
707 let (mut trades, equity_points) = self.simulate_dataframe(&sub, Some(symbol))?;
708 all_trades.append(&mut trades);
709 per_symbol_equity.insert(symbol.clone(), equity_points);
710 }
711
712 let portfolio_equity = aggregate_portfolio_equity(&per_symbol_equity);
713 let n_symbols = unique_symbols.len() as f64;
714 let portfolio_initial = self.per_symbol_initial_cash() * n_symbols;
715 Ok(PerformanceMetrics::from_raw(&all_trades, &portfolio_equity, portfolio_initial))
716 }
717
718 fn run_single_symbol(&self, df: DataFrame) -> Result<BacktestResult, BacktestError> {
719 let (trades, equity_points) = self.simulate_dataframe(&df, None)?;
720
721 let initial_cash = self.per_symbol_initial_cash();
722 let final_equity = equity_points
723 .last()
724 .map(|e| e.equity)
725 .unwrap_or(initial_cash);
726 let total_return = (final_equity - initial_cash) / initial_cash;
727 let num_trades = trades.len() as f64;
728
729 let mut stats = HashMap::new();
730 stats.insert("initial_cash".to_string(), initial_cash);
731 stats.insert("final_equity".to_string(), final_equity);
732 stats.insert("total_return".to_string(), total_return);
733 stats.insert("num_trades".to_string(), num_trades);
734 stats.insert("net_pnl".to_string(), final_equity - initial_cash);
735
736 Ok(BacktestResult {
737 trades: self.trades_to_df(&trades, false)?,
738 equity_curve: self.equity_to_df(&equity_points, false)?,
739 stats,
740 })
741 }
742
743 fn run_multi_symbol(&self, df: DataFrame) -> Result<BacktestResult, BacktestError> {
744 let sym_col = self
745 .config
746 .symbol_col
747 .as_ref()
748 .expect("symbol_col set");
749
750 if df.column(sym_col).is_err() {
751 return Err(BacktestError::InvalidInput(format!(
752 "missing column: {}",
753 sym_col
754 )));
755 }
756
757 let ts_series = df.column(&self.config.timestamp_col)?.clone();
758 let timestamps = self.extract_timestamps(&ts_series)?;
759 let symbols = extract_string_column(df.column(sym_col)?.clone())?;
760 validate_sorted_timestamp_symbol(×tamps, &symbols)?;
761
762 let mut unique_symbols: Vec<String> = Vec::new();
763 let mut seen = std::collections::HashSet::new();
764 for s in &symbols {
765 if seen.insert(s.clone()) {
766 unique_symbols.push(s.clone());
767 }
768 }
769
770 let per_symbol_initial = self.per_symbol_initial_cash();
771 let mut all_trades: Vec<Trade> = Vec::new();
772 let mut per_symbol_equity: HashMap<String, Vec<EquityPoint>> = HashMap::new();
773
774 for symbol in &unique_symbols {
775 let sub = df
776 .clone()
777 .lazy()
778 .filter(col(sym_col).eq(lit(symbol.as_str())))
779 .sort(
780 [&self.config.timestamp_col],
781 SortMultipleOptions::default(),
782 )
783 .collect()?;
784
785 let (mut trades, equity_points) = self.simulate_dataframe(&sub, Some(symbol))?;
786 all_trades.append(&mut trades);
787 per_symbol_equity.insert(symbol.clone(), equity_points);
788 }
789
790 let portfolio_equity = aggregate_portfolio_equity(&per_symbol_equity);
791 let mut combined_equity: Vec<EquityPoint> = per_symbol_equity
792 .values()
793 .flatten()
794 .cloned()
795 .collect();
796 combined_equity.extend(portfolio_equity.clone());
797
798 let n_symbols = unique_symbols.len() as f64;
799 let portfolio_initial = per_symbol_initial * n_symbols;
800 let portfolio_final = portfolio_equity
801 .last()
802 .map(|e| e.equity)
803 .unwrap_or(portfolio_initial);
804 let total_return = (portfolio_final - portfolio_initial) / portfolio_initial;
805 let num_trades = all_trades.len() as f64;
806
807 let mut stats = HashMap::new();
808 stats.insert("initial_cash".to_string(), portfolio_initial);
809 stats.insert("final_equity".to_string(), portfolio_final);
810 stats.insert("total_return".to_string(), total_return);
811 stats.insert("num_trades".to_string(), num_trades);
812 stats.insert("net_pnl".to_string(), portfolio_final - portfolio_initial);
813 stats.insert("num_symbols".to_string(), n_symbols);
814
815 Ok(BacktestResult {
816 trades: self.trades_to_df(&all_trades, true)?,
817 equity_curve: self.equity_to_df(&combined_equity, true)?,
818 stats,
819 })
820 }
821
822 fn per_symbol_initial_cash(&self) -> f64 {
823 match &self.config.execution_model {
824 ExecutionModel::Simple(cm) => cm.initial_cash,
825 _ => 100_000.0,
826 }
827 }
828
829 fn simulate_dataframe(
830 &self,
831 df: &DataFrame,
832 symbol: Option<&str>,
833 ) -> Result<(Vec<Trade>, Vec<EquityPoint>), BacktestError> {
834 let ts_col = &self.config.timestamp_col;
835 let close_col = &self.config.close_col;
836 let sig_col = &self.config.signal_col;
837
838 let ts_series = df.column(ts_col)?.clone();
839 let close_ca = df.column(close_col)?.f64()?.clone();
840 let (signal_vals, signal_metas) = self.load_signals(df, sig_col)?;
841
842 let entry_filters = self.load_entry_filters(df)?;
843 let size_multipliers = self.load_size_multipliers(df)?;
844
845 let n = signal_vals.len();
846 if let Some(ref f) = entry_filters {
847 if f.len() != n {
848 return Err(BacktestError::InvalidInput(
849 "entry_filter column length mismatch".into(),
850 ));
851 }
852 }
853 if let Some(ref m) = size_multipliers {
854 if m.len() != n {
855 return Err(BacktestError::InvalidInput(
856 "size_multiplier column length mismatch".into(),
857 ));
858 }
859 }
860
861 let effective_signals: Vec<f64> = signal_vals
862 .iter()
863 .enumerate()
864 .map(|(i, &raw)| {
865 apply_signal_modifiers(
866 raw,
867 entry_filters.as_ref().map(|f| f[i]),
868 size_multipliers.as_ref().map(|m| m[i]),
869 )
870 })
871 .collect();
872
873 let timestamps = self.extract_timestamps(&ts_series)?;
874 let closes: Vec<f64> = close_ca
875 .into_iter()
876 .map(|v| v.unwrap_or(f64::NAN))
877 .collect();
878
879 if timestamps.len() != closes.len() || closes.len() != effective_signals.len() {
880 return Err(BacktestError::InvalidInput("column length mismatch".into()));
881 }
882
883 let exec = &self.config.execution_model;
884 let sizer = &self.config.position_sizer;
885 let mut effective_metas: Vec<Option<HashMap<String, f64>>> =
886 Vec::with_capacity(effective_signals.len());
887 for (i, &raw) in effective_signals.iter().enumerate() {
888 if raw == 0.0 {
889 effective_metas.push(None);
890 } else {
891 effective_metas.push(signal_metas.get(i).cloned().flatten());
892 }
893 }
894 let delay = self.config.execution_delay;
895 let stops = &self.config.stop_config;
896 let (mut trades, mut equity_points) = run_simulation(
897 ×tamps,
898 &closes,
899 |i| (effective_signals[i], effective_metas[i].clone()),
900 exec,
901 sizer,
902 delay,
903 stops,
904 );
905
906 if let Some(sym) = symbol {
907 let sym_owned = sym.to_string();
908 for t in &mut trades {
909 t.symbol = Some(sym_owned.clone());
910 }
911 for e in &mut equity_points {
912 e.symbol = Some(sym_owned.clone());
913 }
914 }
915
916 Ok((trades, equity_points))
917 }
918
919 fn load_signals(
920 &self,
921 df: &DataFrame,
922 sig_col: &str,
923 ) -> Result<(Vec<f64>, Vec<Option<HashMap<String, f64>>>), BacktestError> {
924 let signal_series = df.column(sig_col)?;
925 let s = signal_series
926 .as_series()
927 .ok_or_else(|| BacktestError::InvalidInput("column has no series backing".into()))?;
928
929 if s.dtype().is_struct() {
930 let ca = s.struct_().map_err(|e| BacktestError::Polars(e))?;
931 let n = ca.len();
932 let mut exposures = Vec::with_capacity(n);
933 let mut metas = Vec::with_capacity(n);
934 for i in 0..n {
935 let (exp, meta) = parse_struct_signal_row(ca, i)?;
936 exposures.push(exp);
937 metas.push(meta);
938 }
939 return Ok((exposures, metas));
940 }
941
942 let signal_vals: Vec<f64> = if signal_series.dtype().is_bool() {
943 signal_series
944 .bool()?
945 .into_iter()
946 .map(|b| if b.unwrap_or(false) { 1.0 } else { 0.0 })
947 .collect()
948 } else {
949 signal_series
950 .f64()?
951 .into_iter()
952 .map(|v| v.unwrap_or(0.0))
953 .collect()
954 };
955 let metas = vec![None; signal_vals.len()];
956 Ok((signal_vals, metas))
957 }
958
959 fn load_entry_filters(&self, df: &DataFrame) -> Result<Option<Vec<bool>>, BacktestError> {
960 let Some(col_name) = &self.config.entry_filter_col else {
961 return Ok(None);
962 };
963 if df.column(col_name).is_err() {
964 return Err(BacktestError::InvalidInput(format!(
965 "missing column: {}",
966 col_name
967 )));
968 }
969 extract_bool_column(df.column(col_name)?.clone())
970 .map(Some)
971 }
972
973 fn load_size_multipliers(&self, df: &DataFrame) -> Result<Option<Vec<f64>>, BacktestError> {
974 let Some(col_name) = &self.config.size_multiplier_col else {
975 return Ok(None);
976 };
977 if df.column(col_name).is_err() {
978 return Err(BacktestError::InvalidInput(format!(
979 "missing column: {}",
980 col_name
981 )));
982 }
983 extract_f64_column(df.column(col_name)?.clone())
984 .map(Some)
985 }
986
987 fn extract_timestamps(&self, col: &Column) -> Result<Vec<DateTime<Utc>>, BacktestError> {
988 let s = col
991 .as_series()
992 .ok_or_else(|| BacktestError::InvalidInput("column has no series backing".into()))?;
993
994 if let Ok(ca) = s.datetime() {
996 return Ok(ca
997 .into_iter()
998 .map(|opt| {
999 opt.map(|v| {
1000 let secs = v / 1000;
1002 let nanos = ((v % 1000) * 1_000_000) as u32;
1003 DateTime::<Utc>::from_timestamp(secs, nanos).unwrap_or_else(Utc::now)
1004 })
1005 .unwrap_or_else(Utc::now)
1006 })
1007 .collect());
1008 }
1009
1010 if let Ok(ca) = s.i64() {
1011 return Ok(ca
1013 .into_iter()
1014 .enumerate()
1015 .map(|(i, opt)| {
1016 let v = opt.unwrap_or(i as i64);
1017 DateTime::<Utc>::from_timestamp(v, 0).unwrap_or_else(Utc::now)
1018 })
1019 .collect());
1020 }
1021
1022 Err(BacktestError::InvalidInput(
1024 "timestamp column must be Datetime or Int64 for this MVP".into(),
1025 ))
1026 }
1027
1028 fn trades_to_df(&self, trades: &[Trade], include_symbol: bool) -> Result<DataFrame, PolarsError> {
1029 if trades.is_empty() {
1030 let mut cols = vec![
1031 Column::new("trade_id".into(), Vec::<u32>::new()),
1032 Column::new("side".into(), Vec::<i8>::new()),
1033 Column::new("entry_ts".into(), Vec::<i64>::new()),
1034 Column::new("entry_price".into(), Vec::<f64>::new()),
1035 Column::new("pnl_net".into(), Vec::<f64>::new()),
1036 ];
1037 if include_symbol {
1038 cols.push(Column::new("symbol".into(), Vec::<Option<String>>::new()));
1039 }
1040 return Ok(DataFrame::new(cols)?);
1041 }
1042
1043 let ids: Vec<u32> = trades.iter().map(|t| t.trade_id).collect();
1044 let sides: Vec<i8> = trades.iter().map(|t| t.side).collect();
1045 let entry_ts: Vec<i64> = trades.iter().map(|t| t.entry_ts.timestamp()).collect();
1046 let entry_px: Vec<f64> = trades.iter().map(|t| t.entry_price).collect();
1047 let exit_ts: Vec<Option<i64>> = trades
1048 .iter()
1049 .map(|t| t.exit_ts.map(|d| d.timestamp()))
1050 .collect();
1051 let exit_px: Vec<Option<f64>> = trades.iter().map(|t| t.exit_price).collect();
1052 let qty: Vec<f64> = trades.iter().map(|t| t.quantity).collect();
1053 let pnl: Vec<f64> = trades.iter().map(|t| t.pnl_net).collect();
1054
1055 let mut cols = vec![
1056 Column::new("trade_id".into(), ids),
1057 Column::new("side".into(), sides),
1058 Column::new("entry_ts".into(), entry_ts),
1059 Column::new("entry_price".into(), entry_px),
1060 Column::new("exit_ts".into(), exit_ts),
1061 Column::new("exit_price".into(), exit_px),
1062 Column::new("quantity".into(), qty),
1063 Column::new("pnl_net".into(), pnl),
1064 ];
1065 if include_symbol {
1066 let symbols: Vec<Option<String>> = trades.iter().map(|t| t.symbol.clone()).collect();
1067 cols.push(Column::new("symbol".into(), symbols));
1068 }
1069
1070 DataFrame::new(cols)
1071 }
1072
1073 fn equity_to_df(&self, points: &[EquityPoint], include_symbol: bool) -> Result<DataFrame, PolarsError> {
1074 if points.is_empty() {
1075 let mut cols = vec![
1076 Column::new("ts".into(), Vec::<i64>::new()),
1077 Column::new("equity".into(), Vec::<f64>::new()),
1078 Column::new("position".into(), Vec::<f64>::new()),
1079 ];
1080 if include_symbol {
1081 cols.push(Column::new("symbol".into(), Vec::<Option<String>>::new()));
1082 }
1083 return Ok(DataFrame::new(cols)?);
1084 }
1085
1086 let ts: Vec<i64> = points.iter().map(|p| p.ts.timestamp()).collect();
1087 let eq: Vec<f64> = points.iter().map(|p| p.equity).collect();
1088 let pos: Vec<f64> = points.iter().map(|p| p.position).collect();
1089 let cash: Vec<f64> = points.iter().map(|p| p.cash).collect();
1090 let close: Vec<f64> = points.iter().map(|p| p.close).collect();
1091
1092 let mut cols = vec![
1093 Column::new("ts".into(), ts),
1094 Column::new("equity".into(), eq),
1095 Column::new("cash".into(), cash),
1096 Column::new("position".into(), pos),
1097 Column::new("close".into(), close),
1098 ];
1099 if include_symbol {
1100 let symbols: Vec<Option<String>> = points.iter().map(|p| p.symbol.clone()).collect();
1101 cols.push(Column::new("symbol".into(), symbols));
1102 }
1103
1104 DataFrame::new(cols)
1105 }
1106}
1107
/// Applies the optional entry filter and size multiplier to a raw signal.
///
/// * `entry_filter == Some(false)` vetoes the bar (returns `0.0`).
/// * `size_multiplier`, when present, scales the signal.
///
/// Any non-finite or zero result collapses to a positive `0.0`, so callers never receive
/// NaN/inf exposure.
pub fn apply_signal_modifiers(
    raw_signal: f64,
    entry_filter: Option<bool>,
    size_multiplier: Option<f64>,
) -> f64 {
    if entry_filter == Some(false) {
        return 0.0;
    }
    let scaled = match size_multiplier {
        Some(multiplier) => raw_signal * multiplier,
        None => raw_signal,
    };
    if !scaled.is_finite() || scaled == 0.0 {
        0.0
    } else {
        scaled
    }
}
1128
1129fn extract_bool_column(col: Column) -> Result<Vec<bool>, BacktestError> {
1130 let s = col
1131 .as_series()
1132 .ok_or_else(|| BacktestError::InvalidInput("column has no series backing".into()))?;
1133 if let Ok(ca) = s.bool() {
1134 return Ok(ca
1135 .into_iter()
1136 .map(|opt| opt.unwrap_or(false))
1137 .collect());
1138 }
1139 Err(BacktestError::InvalidInput(
1140 "entry_filter column must be boolean".into(),
1141 ))
1142}
1143
1144fn extract_f64_column(col: Column) -> Result<Vec<f64>, BacktestError> {
1145 let s = col
1146 .as_series()
1147 .ok_or_else(|| BacktestError::InvalidInput("column has no series backing".into()))?;
1148 if let Ok(ca) = s.f64() {
1149 return Ok(ca.into_iter().map(|opt| opt.unwrap_or(0.0)).collect());
1150 }
1151 Err(BacktestError::InvalidInput(
1152 "size_multiplier column must be f64".into(),
1153 ))
1154}
1155
1156fn extract_string_column(col: Column) -> Result<Vec<String>, BacktestError> {
1157 let s = col
1158 .as_series()
1159 .ok_or_else(|| BacktestError::InvalidInput("column has no series backing".into()))?;
1160 if let Ok(ca) = s.str() {
1161 return Ok(ca
1162 .into_iter()
1163 .map(|opt| opt.unwrap_or_default().to_string())
1164 .collect());
1165 }
1166 Err(BacktestError::InvalidInput(
1167 "symbol column must be Utf8/String".into(),
1168 ))
1169}
1170
1171fn validate_sorted_timestamp_symbol(
1172 timestamps: &[DateTime<Utc>],
1173 symbols: &[String],
1174) -> Result<(), BacktestError> {
1175 if timestamps.len() != symbols.len() {
1176 return Err(BacktestError::InvalidInput("column length mismatch".into()));
1177 }
1178 for i in 1..timestamps.len() {
1179 let prev = (×tamps[i - 1], &symbols[i - 1]);
1180 let curr = (×tamps[i], &symbols[i]);
1181 if curr < prev {
1182 return Err(BacktestError::UnsortedData);
1183 }
1184 }
1185 Ok(())
1186}
1187
1188fn aggregate_portfolio_equity(per_symbol: &HashMap<String, Vec<EquityPoint>>) -> Vec<EquityPoint> {
1189 use std::collections::BTreeSet;
1190
1191 let mut ts_set = BTreeSet::new();
1192 for points in per_symbol.values() {
1193 for p in points {
1194 ts_set.insert(p.ts);
1195 }
1196 }
1197
1198 ts_set
1199 .into_iter()
1200 .map(|ts| {
1201 let mut total_equity = 0.0;
1202 let mut total_cash = 0.0;
1203 let mut total_position = 0.0;
1204 for points in per_symbol.values() {
1205 if let Some(p) = points.iter().find(|p| p.ts == ts) {
1206 total_equity += p.equity;
1207 total_cash += p.cash;
1208 total_position += p.position;
1209 }
1210 }
1211 EquityPoint {
1212 ts,
1213 symbol: None,
1214 equity: total_equity,
1215 cash: total_cash,
1216 position: total_position,
1217 close: 0.0,
1218 }
1219 })
1220 .collect()
1221}
1222
1223pub fn backtest_simple_bool_signal(
1226 ohlcv: DataFrame,
1227 signal_col: &str,
1228) -> Result<BacktestResult, BacktestError> {
1229 let config = BacktestConfig {
1230 signal_col: signal_col.to_string(),
1231 ..Default::default()
1232 };
1233 let engine = BacktestEngine::new(config);
1234 engine.run(ohlcv.lazy())
1235}
1236
1237fn run_simulation(
1245 timestamps: &[DateTime<Utc>],
1246 closes: &[f64],
1247 mut next_signal: impl FnMut(usize) -> (f64, Option<HashMap<String, f64>>),
1248 exec: &ExecutionModel,
1249 sizer: &Option<InitialRiskPositionSizer>,
1250 execution_delay: ExecutionDelay,
1251 stop_config: &StopConfig,
1252) -> (Vec<Trade>, Vec<EquityPoint>) {
1253 let mut cash = match exec {
1254 ExecutionModel::Simple(cm) => cm.initial_cash,
1255 ExecutionModel::HighFidelity { .. } => 100_000.0,
1256 };
1257 let mut current_exposure: f64 = 0.0;
1258 let mut entry_price: f64 = 0.0;
1259 let mut entry_ts: Option<DateTime<Utc>> = None;
1260 let mut entry_metadata: Option<HashMap<String, f64>> = None;
1261 let mut trailing_stop_level: Option<f64> = None;
1262 let mut need_signal_reset = false;
1263 let mut trade_id: u32 = 0;
1264 let mut trades: Vec<Trade> = Vec::new();
1265 let mut equity_points: Vec<EquityPoint> = Vec::with_capacity(closes.len());
1266
1267 let mut record_position_exit =
1268 |cash: &mut f64,
1269 tid: u32,
1270 side: i8,
1271 qty: f64,
1272 entry_px: f64,
1273 ets: DateTime<Utc>,
1274 exit_bar: usize,
1275 meta: Option<HashMap<String, f64>>| {
1276 let close = closes[exit_bar];
1277 let is_buy = side == -1;
1279 let fill_price = exec.slippage_price(close, qty, is_buy, None);
1280 let notional = fill_price * qty;
1281 let cost = exec.commission_for(qty, fill_price);
1282 let gross_pnl = if side == 1 {
1283 (fill_price - entry_px) * qty
1284 } else {
1285 (entry_px - fill_price) * qty
1286 };
1287 let net_pnl = gross_pnl - cost;
1288 if side == 1 {
1289 *cash += notional - cost;
1290 } else {
1291 *cash -= notional + cost;
1292 }
1293 trades.push(Trade {
1294 trade_id: tid,
1295 symbol: None,
1296 side,
1297 entry_ts: ets,
1298 entry_price: entry_px,
1299 entry_fill_price: entry_px,
1300 exit_ts: Some(timestamps[exit_bar]),
1301 exit_price: Some(close),
1302 exit_fill_price: Some(fill_price),
1303 pnl_gross: gross_pnl,
1304 costs: cost,
1305 pnl_net: net_pnl,
1306 quantity: qty,
1307 entry_metadata: meta,
1308 });
1309 };
1310
1311 let open_position = |cash: &mut f64,
1312 tid: u32,
1313 desired: f64,
1314 fill_bar: usize,
1315 meta: Option<HashMap<String, f64>>|
1316 -> (u32, f64, f64, Option<DateTime<Utc>>, Option<HashMap<String, f64>>, Option<f64>) {
1317 let qty = desired.abs();
1318 let is_long = desired > 0.0;
1319 let is_buy = is_long;
1320 let close = closes[fill_bar];
1321 let fill_price = exec.slippage_price(close, qty, is_buy, None);
1322 let notional = fill_price * qty;
1323 let cost = exec.commission_for(qty, fill_price);
1324 if is_long {
1325 *cash -= notional + cost;
1326 } else {
1327 *cash += notional - cost;
1328 }
1329 let new_tid = tid + 1;
1330 let exposure = if is_long { qty } else { -qty };
1331 let trail = stop_config.trailing_stop_pct.map(|pct| {
1332 if is_long {
1333 fill_price * (1.0 - pct)
1334 } else {
1335 fill_price * (1.0 + pct)
1336 }
1337 });
1338 (
1339 new_tid,
1340 exposure,
1341 fill_price,
1342 Some(timestamps[fill_bar]),
1343 meta,
1344 trail,
1345 )
1346 };
1347
1348 for i in 0..closes.len() {
1349 let close = closes[i];
1350 if !close.is_finite() {
1351 let equity = cash + current_exposure * close;
1352 equity_points.push(EquityPoint {
1353 ts: timestamps[i],
1354 symbol: None,
1355 equity,
1356 cash,
1357 position: current_exposure,
1358 close,
1359 });
1360 continue;
1361 }
1362
1363 if current_exposure != 0.0 && stop_config.has_stops() {
1365 let is_long = current_exposure > 0.0;
1366 let qty = current_exposure.abs();
1367
1368 if let Some(trail_pct) = stop_config.trailing_stop_pct {
1369 if is_long {
1370 let new_level = close * (1.0 - trail_pct);
1371 trailing_stop_level = Some(match trailing_stop_level {
1372 Some(prev) => prev.max(new_level),
1373 None => new_level,
1374 });
1375 } else {
1376 let new_level = close * (1.0 + trail_pct);
1377 trailing_stop_level = Some(match trailing_stop_level {
1378 Some(prev) => prev.min(new_level),
1379 None => new_level,
1380 });
1381 }
1382 }
1383
1384 let mut stop_out = false;
1385 if is_long {
1386 if let Some(tp) = stop_config.take_profit_pct {
1387 if close >= entry_price * (1.0 + tp) {
1388 stop_out = true;
1389 }
1390 }
1391 if !stop_out {
1392 let mut effective_stop = f64::NEG_INFINITY;
1393 if let Some(sl) = stop_config.stop_loss_pct {
1394 effective_stop = entry_price * (1.0 - sl);
1395 }
1396 if let Some(level) = trailing_stop_level {
1397 effective_stop = effective_stop.max(level);
1398 }
1399 if effective_stop > f64::NEG_INFINITY && close <= effective_stop {
1400 stop_out = true;
1401 }
1402 }
1403 } else {
1404 if let Some(tp) = stop_config.take_profit_pct {
1405 if close <= entry_price * (1.0 - tp) {
1406 stop_out = true;
1407 }
1408 }
1409 if !stop_out {
1410 let mut effective_stop = f64::INFINITY;
1411 if let Some(sl) = stop_config.stop_loss_pct {
1412 effective_stop = entry_price * (1.0 + sl);
1413 }
1414 if let Some(level) = trailing_stop_level {
1415 effective_stop = effective_stop.min(level);
1416 }
1417 if effective_stop < f64::INFINITY && close >= effective_stop {
1418 stop_out = true;
1419 }
1420 }
1421 }
1422
1423 if stop_out {
1424 if let Some(ets) = entry_ts.take() {
1425 let side = if is_long { 1 } else { -1 };
1426 record_position_exit(
1427 &mut cash,
1428 trade_id,
1429 side,
1430 qty,
1431 entry_price,
1432 ets,
1433 i,
1434 entry_metadata.clone(),
1435 );
1436 current_exposure = 0.0;
1437 entry_price = 0.0;
1438 trailing_stop_level = None;
1439 entry_metadata = None;
1440 need_signal_reset = true;
1441 }
1442 }
1443 }
1444
1445 let (raw_exposure, meta) = match signal_bar_index(i, execution_delay) {
1446 Some(si) => next_signal(si),
1447 None => (0.0, None),
1448 };
1449 let current_equity = cash + current_exposure * close;
1451 let desired_exposure = if let Some(s) = sizer {
1452 s.compute_sized_exposure(raw_exposure, &meta, close, current_equity)
1453 } else {
1454 raw_exposure
1455 };
1456 let desired = if desired_exposure.is_finite() && desired_exposure != 0.0 {
1457 desired_exposure
1458 } else {
1459 0.0
1460 };
1461
1462 if desired == 0.0 {
1463 need_signal_reset = false;
1464 }
1465
1466 let currently_in = current_exposure != 0.0;
1467
1468 if desired == 0.0 && currently_in {
1469 if let Some(ets) = entry_ts.take() {
1470 let side = if current_exposure > 0.0 { 1 } else { -1 };
1471 record_position_exit(
1472 &mut cash,
1473 trade_id,
1474 side,
1475 current_exposure.abs(),
1476 entry_price,
1477 ets,
1478 i,
1479 meta.clone(),
1480 );
1481 current_exposure = 0.0;
1482 entry_price = 0.0;
1483 trailing_stop_level = None;
1484 entry_metadata = None;
1485 }
1486 } else if desired != 0.0 && !need_signal_reset {
1487 let want_long = desired > 0.0;
1488 let in_long = current_exposure > 0.0;
1489 let in_short = current_exposure < 0.0;
1490 let flip = (want_long && in_short) || (!want_long && in_long);
1491
1492 if flip {
1493 if let Some(ets) = entry_ts.take() {
1494 let side = if in_long { 1 } else { -1 };
1495 record_position_exit(
1496 &mut cash,
1497 trade_id,
1498 side,
1499 current_exposure.abs(),
1500 entry_price,
1501 ets,
1502 i,
1503 entry_metadata.clone(),
1504 );
1505 current_exposure = 0.0;
1506 entry_price = 0.0;
1507 trailing_stop_level = None;
1508 entry_metadata = None;
1509 }
1510 }
1511
1512 if current_exposure == 0.0 {
1513 let (new_tid, exp, ep, ets, em, trail) =
1514 open_position(&mut cash, trade_id, desired, i, meta.clone());
1515 trade_id = new_tid;
1516 current_exposure = exp;
1517 entry_price = ep;
1518 entry_ts = ets;
1519 entry_metadata = em;
1520 trailing_stop_level = trail;
1521 }
1522 }
1523
1524 let equity = cash + current_exposure * close;
1525 equity_points.push(EquityPoint {
1526 ts: timestamps[i],
1527 symbol: None,
1528 equity,
1529 cash,
1530 position: current_exposure,
1531 close,
1532 });
1533 }
1534
1535 if current_exposure != 0.0 {
1537 let last_close = *closes.last().unwrap();
1538 let qty = current_exposure.abs();
1539 let side = if current_exposure > 0.0 { 1 } else { -1 };
1540 let gross = if side == 1 {
1541 (last_close - entry_price) * qty
1542 } else {
1543 (entry_price - last_close) * qty
1544 };
1545 if let Some(ets) = entry_ts {
1546 trades.push(Trade {
1547 trade_id,
1548 symbol: None,
1549 side,
1550 entry_ts: ets,
1551 entry_price,
1552 entry_fill_price: entry_price,
1553 exit_ts: None,
1554 exit_price: Some(last_close),
1555 exit_fill_price: None,
1556 pnl_gross: gross,
1557 costs: 0.0,
1558 pnl_net: gross,
1559 quantity: qty,
1560 entry_metadata: None,
1561 });
1562 }
1563 }
1564
1565 (trades, equity_points)
1566}
1567
1568pub fn run_streaming_simulation<G>(
1575 bars: &[Bar],
1576 mut generator: G,
1577 config: BacktestConfig,
1578) -> Result<BacktestResult, BacktestError>
1579where
1580 G: for<'a> Next<&'a Bar, Output = StrategySignal>,
1581{
1582 if bars.is_empty() {
1583 return Err(BacktestError::InvalidInput("empty bars".into()));
1584 }
1585
1586 let timestamps: Vec<DateTime<Utc>> = bars.iter().map(|b| b.ts).collect();
1587 let closes: Vec<f64> = bars.iter().map(|b| b.close).collect();
1588
1589 let exec = &config.execution_model;
1590 let sizer = &config.position_sizer;
1591
1592 let delay = config.execution_delay;
1593 let stops = &config.stop_config;
1594 let (trades, equity_points) = run_simulation(
1595 ×tamps,
1596 &closes,
1597 |i| {
1598 let sig = generator.next(&bars[i]);
1599 (sig.exposure, sig.metadata.clone())
1600 },
1601 exec,
1602 sizer,
1603 delay,
1604 stops,
1605 );
1606
1607 let trades_df = if trades.is_empty() {
1612 DataFrame::new(vec![
1613 Column::new("trade_id".into(), Vec::<u32>::new()),
1614 Column::new("side".into(), Vec::<i8>::new()),
1615 Column::new("entry_ts".into(), Vec::<i64>::new()),
1616 Column::new("entry_price".into(), Vec::<f64>::new()),
1617 Column::new("pnl_net".into(), Vec::<f64>::new()),
1618 ])?
1619 } else {
1620 let ids: Vec<u32> = trades.iter().map(|t| t.trade_id).collect();
1621 let sides: Vec<i8> = trades.iter().map(|t| t.side).collect();
1622 let entry_ts: Vec<i64> = trades.iter().map(|t| t.entry_ts.timestamp()).collect();
1623 let entry_px: Vec<f64> = trades.iter().map(|t| t.entry_price).collect();
1624 let exit_ts: Vec<Option<i64>> = trades
1625 .iter()
1626 .map(|t| t.exit_ts.map(|d| d.timestamp()))
1627 .collect();
1628 let exit_px: Vec<Option<f64>> = trades.iter().map(|t| t.exit_price).collect();
1629 let pnl: Vec<f64> = trades.iter().map(|t| t.pnl_net).collect();
1630
1631 DataFrame::new(vec![
1632 Column::new("trade_id".into(), ids),
1633 Column::new("side".into(), sides),
1634 Column::new("entry_ts".into(), entry_ts),
1635 Column::new("entry_price".into(), entry_px),
1636 Column::new("exit_ts".into(), exit_ts),
1637 Column::new("exit_price".into(), exit_px),
1638 Column::new("pnl_net".into(), pnl),
1639 ])?
1640 };
1641
1642 let equity_df = if equity_points.is_empty() {
1643 DataFrame::new(vec![
1644 Column::new("ts".into(), Vec::<i64>::new()),
1645 Column::new("equity".into(), Vec::<f64>::new()),
1646 Column::new("position".into(), Vec::<f64>::new()),
1647 ])?
1648 } else {
1649 let ts: Vec<i64> = equity_points.iter().map(|p| p.ts.timestamp()).collect();
1650 let eq: Vec<f64> = equity_points.iter().map(|p| p.equity).collect();
1651 let pos: Vec<f64> = equity_points.iter().map(|p| p.position).collect();
1652 let cash: Vec<f64> = equity_points.iter().map(|p| p.cash).collect();
1653 let close: Vec<f64> = equity_points.iter().map(|p| p.close).collect();
1654
1655 DataFrame::new(vec![
1656 Column::new("ts".into(), ts),
1657 Column::new("equity".into(), eq),
1658 Column::new("cash".into(), cash),
1659 Column::new("position".into(), pos),
1660 Column::new("close".into(), close),
1661 ])?
1662 };
1663
1664 let initial_cash = match &config.execution_model {
1665 ExecutionModel::Simple(cm) => cm.initial_cash,
1666 _ => 100_000.0,
1667 };
1668 let final_equity = equity_points
1669 .last()
1670 .map(|e| e.equity)
1671 .unwrap_or(initial_cash);
1672 let total_return = (final_equity - initial_cash) / initial_cash;
1673 let num_trades = trades.len() as f64;
1674
1675 let mut stats = HashMap::new();
1676 stats.insert("initial_cash".to_string(), initial_cash);
1677 stats.insert("final_equity".to_string(), final_equity);
1678 stats.insert("total_return".to_string(), total_return);
1679 stats.insert("num_trades".to_string(), num_trades);
1680 stats.insert("net_pnl".to_string(), final_equity - initial_cash);
1681
1682 Ok(BacktestResult {
1683 trades: trades_df,
1684 equity_curve: equity_df,
1685 stats,
1686 })
1687}
1688
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use rand::Rng;
    use quantwave_core::features::CyberCycleFeatureExtractor;
    use quantwave_core::regimes::MarketRegime;
    use quantwave_core::regimes::tar::TAR;
    use quantwave_core::traits::Next;
    use std::collections::HashMap;

    // Long while the signal is 1.0 over a rising stretch, flat afterwards. Expect exactly one
    // trade, a net gain, and one equity point per bar whose last value matches the stats.
    #[test]
    fn test_basic_long_only_flip_on_synthetic() {
        let n: usize = 6;
        let timestamps: Vec<i64> = (0..n)
            .map(|i| 1_700_000_000i64 + (i as i64) * 3600)
            .collect();
        let closes = vec![100.0, 101.0, 102.5, 103.0, 102.0, 101.0];
        let signals = vec![0.0, 1.0, 1.0, 1.0, 0.0, 0.0];

        let df = DataFrame::new(vec![
            Column::new("timestamp".into(), timestamps),
            Column::new("close".into(), closes),
            Column::new("signal".into(), signals),
        ])
        .unwrap();

        let result = backtest_simple_bool_signal(df, "signal").expect("sim should succeed");

        assert_eq!(result.trades.height(), 1);
        let num_trades: f64 = *result.stats.get("num_trades").unwrap();
        assert_relative_eq!(num_trades, 1.0, epsilon = 1e-9);

        let final_eq = *result.stats.get("final_equity").unwrap();
        let init = 100_000.0;
        assert!(
            final_eq > init,
            "equity should grow on winning long: {} vs {}",
            final_eq,
            init
        );

        assert_eq!(result.equity_curve.height(), n);

        let last_equity = result
            .equity_curve
            .column("equity")
            .unwrap()
            .f64()
            .unwrap()
            .get(n - 1)
            .unwrap();
        assert_relative_eq!(last_equity, final_eq, epsilon = 1e-6);
    }

    // A signal that is always flat must produce no trades and leave equity at initial cash.
    #[test]
    fn test_flat_always_signal_produces_no_trades_and_flat_equity() {
        let n: usize = 5;
        let ts: Vec<i64> = (0..n).map(|i| 1_700_000_100 + i as i64).collect();
        let closes = vec![100.0; n];
        let signals = vec![0.0; n];

        let df = DataFrame::new(vec![
            Column::new("timestamp".into(), ts),
            Column::new("close".into(), closes),
            Column::new("signal".into(), signals),
        ])
        .unwrap();

        let result = backtest_simple_bool_signal(df, "signal").unwrap();

        assert_eq!(result.trades.height(), 0);
        let num = *result.stats.get("num_trades").unwrap();
        assert_relative_eq!(num, 0.0, epsilon = 1e-9);

        let final_equity_val = *result.stats.get("final_equity").unwrap();
        assert_relative_eq!(final_equity_val, 100_000.0, epsilon = 1e-4);
    }

    // Compare the engine's final equity against a hand-rolled reference: 1 unit, fills at the
    // close +/- 2 bps slippage, and a single 5 bps commission on each fill's notional (the
    // `CostModel` defaults).
    // NOTE(review): assumes the default execution model applies these costs multiplicatively
    // at the signal bar's close.
    #[test]
    fn test_synthetic_with_small_random_walk_and_bool_signal_matches_manual_calc() {
        let mut rng = rand::thread_rng();
        let n: usize = 8;
        let mut price = 100.0_f64;
        let mut closes = Vec::with_capacity(n);
        let signals = vec![0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0];
        let mut ts = Vec::with_capacity(n);

        for i in 0..n {
            ts.push(1_700_000_200 + i as i64);
            closes.push(price);
            price += rng.gen_range(-0.8..1.2);
        }

        let df = DataFrame::new(vec![
            Column::new("timestamp".into(), ts),
            Column::new("close".into(), closes.clone()),
            Column::new("signal".into(), signals.clone()),
        ])
        .unwrap();

        let result = backtest_simple_bool_signal(df, "signal").unwrap();

        let slip = 0.0002;
        let comm = 0.0005;
        let init = 100_000.0;
        let mut cash = init;
        let mut pos = 0.0;
        let mut manual_equity = init;

        for i in 0..n {
            let c = closes[i];
            let s = signals[i] > 0.0;

            if s && pos == 0.0 {
                // Buy 1 unit: pay the slipped price plus commission on it.
                let fp = c * (1.0 + slip);
                cash -= fp * (1.0 + comm);
                pos = 1.0;
            } else if !s && pos > 0.0 {
                // Sell 1 unit: receive the slipped price minus commission, charged exactly
                // once. The previous reference subtracted the exit commission a second time.
                let fp = c * (1.0 - slip);
                cash += fp * (1.0 - comm);
                pos = 0.0;
            }
            manual_equity = cash + pos * c;
        }

        let engine_final = *result.stats.get("final_equity").unwrap();
        assert_relative_eq!(engine_final, manual_equity, epsilon = 0.5);
    }

    // Toy price-action detector: reports the high-low range of the last `max_len` prices
    // (floored at 0.1) as a "pole height". It reports 1.0 until three prices have been seen.
    #[derive(Debug, Clone)]
    struct SyntheticPoleHeightDetector {
        window: Vec<f64>,
        max_len: usize,
    }

    impl SyntheticPoleHeightDetector {
        fn new(max_len: usize) -> Self {
            Self {
                window: Vec::with_capacity(max_len),
                max_len,
            }
        }
    }

    #[derive(Debug, Clone, Copy)]
    struct PoleOutput {
        pole_height: f64,
        _strength: f64,
    }

    impl Next<f64> for SyntheticPoleHeightDetector {
        type Output = PoleOutput;

        fn next(&mut self, price: f64) -> PoleOutput {
            self.window.push(price);
            if self.window.len() > self.max_len {
                self.window.remove(0);
            }
            let h = if self.window.len() >= 3 {
                let mn = self.window.iter().fold(f64::INFINITY, |a, &b| a.min(b));
                let mx = self.window.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                (mx - mn).max(0.1)
            } else {
                1.0
            };
            PoleOutput {
                pole_height: h,
                _strength: (h / 8.0).clamp(0.3, 1.0),
            }
        }
    }

    // Composite strategy used for the parity test. It is exposed only when the TAR regime
    // is Steady/Cluster/Bull and the cycle momentum magnitude exceeds `feat_thresh`. Exposure
    // is then sized from the pole height.
    #[derive(Debug, Clone)]
    struct RegimeFeaturePAStrategy {
        regime: TAR,
        cycle: CyberCycleFeatureExtractor,
        pa: SyntheticPoleHeightDetector,
        feat_thresh: f64,
    }

    impl RegimeFeaturePAStrategy {
        fn new() -> Self {
            Self {
                regime: TAR::new(105.0),
                cycle: CyberCycleFeatureExtractor::new(14),
                pa: SyntheticPoleHeightDetector::new(6),
                feat_thresh: 0.02,
            }
        }
    }

    impl Next<&Bar> for RegimeFeaturePAStrategy {
        type Output = StrategySignal;

        fn next(&mut self, bar: &Bar) -> StrategySignal {
            let regime = self.regime.next(bar.close);
            let feat = self.cycle.next(bar.close);
            let pa = self.pa.next(bar.close);

            let regime_ok = matches!(
                regime,
                MarketRegime::Steady | MarketRegime::Cluster(_) | MarketRegime::Bull
            );
            let feat_ok = feat.cycle_momentum.abs() > self.feat_thresh;

            let exposure = if regime_ok && feat_ok {
                (pa.pole_height / 4.0).clamp(0.4, 2.2)
            } else {
                0.0
            };

            let mut meta = HashMap::new();
            meta.insert("pole_height".to_string(), pa.pole_height);
            meta.insert("cycle_momentum".to_string(), feat.cycle_momentum);
            meta.insert("regime_ok".to_string(), if regime_ok { 1.0 } else { 0.0 });

            StrategySignal {
                exposure,
                metadata: Some(meta),
            }
        }
    }

    // The batch engine gets a precomputed exposure column from one generator instance. The
    // streaming path gets a fresh generator fed bar by bar. Equity curves, headline stats and
    // trade counts must agree.
    #[test]
    fn test_batch_vs_streaming_parity_regime_feature_rich_pa_pole_sizing() {
        let n: usize = 120;
        let mut timestamps = Vec::with_capacity(n);
        let mut closes = Vec::with_capacity(n);

        for i in 0..n {
            let secs = 1_700_000_500i64 + (i as i64) * 3600;
            timestamps.push(chrono::DateTime::<chrono::Utc>::from_timestamp(secs, 0).unwrap());
            // Sine wave plus a gentle upward drift.
            let wave = (i as f64 * 0.18).sin() * 4.5;
            closes.push(101.5 + wave + (i as f64 * 0.008));
        }

        let bars: Vec<Bar> = timestamps
            .iter()
            .zip(closes.iter())
            .map(|(&ts, &close)| Bar { ts, close })
            .collect();

        let mut batch_gen = RegimeFeaturePAStrategy::new();
        let exposures: Vec<f64> = bars.iter().map(|bar| batch_gen.next(bar).exposure).collect();

        let df = DataFrame::new(vec![
            Column::new(
                "timestamp".into(),
                timestamps.iter().map(|t| t.timestamp()).collect::<Vec<_>>(),
            ),
            Column::new("close".into(), closes),
            Column::new("signal".into(), exposures),
        ])
        .unwrap();

        let batch_res = backtest_simple_bool_signal(df, "signal").expect("batch parity run");

        let stream_gen = RegimeFeaturePAStrategy::new();
        let stream_res = run_streaming_simulation(&bars, stream_gen, BacktestConfig::default())
            .expect("streaming parity run");

        let b_eq = batch_res
            .equity_curve
            .column("equity")
            .unwrap()
            .f64()
            .unwrap()
            .into_iter()
            .map(|v| v.unwrap_or(0.0))
            .collect::<Vec<_>>();
        let s_eq = stream_res
            .equity_curve
            .column("equity")
            .unwrap()
            .f64()
            .unwrap()
            .into_iter()
            .map(|v| v.unwrap_or(0.0))
            .collect::<Vec<_>>();

        assert_eq!(b_eq.len(), s_eq.len(), "equity curve lengths must match");
        // A single check that names the first diverging bar. The old `assert_relative_eq!`
        // fired first with no index, so the index-reporting panic was effectively unreachable.
        // A NaN on either side also fails, because the comparison is false.
        for (i, (b, s)) in b_eq.iter().zip(s_eq.iter()).enumerate() {
            assert!(
                (b - s).abs() <= 1e-7,
                "equity diverged at bar {}: {} vs {}",
                i,
                b,
                s
            );
        }

        for k in ["final_equity", "net_pnl", "num_trades"] {
            let bv = *batch_res.stats.get(k).unwrap();
            let sv = *stream_res.stats.get(k).unwrap();
            assert_relative_eq!(bv, sv, epsilon = 1e-6, max_relative = 1e-6);
        }

        assert_eq!(
            batch_res.trades.height(),
            stream_res.trades.height(),
            "trade counts must match exactly for parity"
        );

        // Guard against a vacuous pass: parity with zero trades would prove nothing.
        assert!(
            batch_res.trades.height() >= 1,
            "parity test strategy must generate >=1 trade on synthetic data"
        );
    }
}
2051
#[cfg(test)]
mod integration_example_between_epics {
    use super::*;
    use quantwave_core::features::HurstFeatureExtractor;

    // Smoke test: exposures derived from a Hurst-based persistence feature drive the batch
    // engine end to end, producing one equity point per bar.
    #[test]
    fn ml_features_feed_backtester_with_metadata() {
        let n = 60;
        let closes: Vec<f64> = (0..n).map(|i| 100.0 + i as f64 * 0.25).collect();
        let timestamps: Vec<i64> = (0..n).map(|i| 1_700_000_000i64 + i as i64).collect();

        // Long (1.0) whenever the extractor reports persistence above 0.52, otherwise flat.
        let mut hurst = HurstFeatureExtractor::new(15);
        let exposures: Vec<f64> = closes
            .iter()
            .map(|&c| {
                if hurst.next(c).persistence > 0.52 {
                    1.0
                } else {
                    0.0
                }
            })
            .collect();

        let frame = df![
            "timestamp" => timestamps,
            "close" => closes,
            "exposure" => exposures,
        ]
        .unwrap();

        let result = BacktestEngine::new(BacktestConfig {
            signal_col: "exposure".to_string(),
            ..Default::default()
        })
        .run(frame.lazy())
        .unwrap();

        println!(
            "Integration smoke test: {} trades produced using ML feature (Hurst) driven exposure",
            result.trades.height()
        );
        assert_eq!(result.equity_curve.height(), n);
    }

    // The sizer must honour both metadata keys (`pole_height_atr` and `fraction_at_risk`).
    // With no metadata it must pass the raw exposure through unchanged.
    #[test]
    fn test_initial_risk_position_sizer_with_pole_height_and_fraction() {
        let sizer = InitialRiskPositionSizer { initial_risk: 0.01, max_target_pct: 0.5 };
        let (price, equity) = (100.0, 1_000_000.0);

        let pole_meta = Some(HashMap::from([("pole_height_atr".to_string(), 2.0)]));
        let sized_by_pole = sizer.compute_sized_exposure(1.0, &pole_meta, price, equity);
        assert!((sized_by_pole - 5000.0).abs() < 1.0);

        let fraction_meta = Some(HashMap::from([("fraction_at_risk".to_string(), 0.02)]));
        let sized_by_fraction = sizer.compute_sized_exposure(1.0, &fraction_meta, price, equity);
        assert!((sized_by_fraction - 5000.0).abs() < 1.0);

        let no_meta: Option<HashMap<String, f64>> = None;
        let passthrough = sizer.compute_sized_exposure(123.0, &no_meta, price, equity);
        assert!((passthrough - 123.0).abs() < 1e-9);
    }
}