1use chrono::{DateTime, Utc};
67use polars::prelude::*;
68#[allow(unused_imports)]
69use quantwave_core::traits::Next; use serde::{Deserialize, Serialize};
71use std::collections::HashMap;
72use thiserror::Error;
73
74#[derive(Error, Debug)]
76pub enum BacktestError {
77 #[error("Polars error during simulation: {0}")]
78 Polars(#[from] PolarsError),
79
80 #[error("Invalid input: {0}")]
81 InvalidInput(String),
82
83 #[error("Data must be sorted by timestamp (and symbol for multi-symbol runs)")]
84 UnsortedData,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct CostModel {
90 pub commission_bps: f64,
92 pub slippage_bps: f64,
94 pub initial_cash: f64,
96}
97
98impl Default for CostModel {
99 fn default() -> Self {
100 Self {
101 commission_bps: 5.0, slippage_bps: 2.0, initial_cash: 100_000.0,
104 }
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct BacktestConfig {
111 pub cost_model: CostModel,
112 pub timestamp_col: String,
114 pub symbol_col: Option<String>,
115 pub close_col: String,
116 pub signal_col: String,
124 pub entry_filter_col: Option<String>,
128 pub size_multiplier_col: Option<String>,
131}
132
133impl Default for BacktestConfig {
134 fn default() -> Self {
135 Self {
136 cost_model: CostModel::default(),
137 timestamp_col: "timestamp".to_string(),
138 symbol_col: None,
139 close_col: "close".to_string(),
140 signal_col: "signal".to_string(),
141 entry_filter_col: None,
142 size_multiplier_col: None,
143 }
144 }
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct Trade {
150 pub trade_id: u32,
151 pub symbol: Option<String>,
152 pub side: i8, pub entry_ts: DateTime<Utc>,
154 pub entry_price: f64,
155 pub entry_fill_price: f64, pub exit_ts: Option<DateTime<Utc>>,
157 pub exit_price: Option<f64>,
158 pub exit_fill_price: Option<f64>,
159 pub pnl_gross: f64,
160 pub costs: f64,
161 pub pnl_net: f64,
162 pub quantity: f64,
165 pub entry_metadata: Option<HashMap<String, f64>>,
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct EquityPoint {
173 pub ts: DateTime<Utc>,
174 pub symbol: Option<String>, pub equity: f64,
176 pub cash: f64,
177 pub position: f64, pub close: f64,
179}
180
181#[derive(Debug)]
183pub struct BacktestResult {
184 pub trades: DataFrame,
186 pub equity_curve: DataFrame,
188 pub stats: HashMap<String, f64>,
191}
192
193#[derive(Debug, Clone)]
196pub struct Bar {
197 pub ts: DateTime<Utc>,
198 pub close: f64,
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct StrategySignal {
206 pub exposure: f64,
209 pub metadata: Option<HashMap<String, f64>>,
212}
213
214impl Default for StrategySignal {
215 fn default() -> Self {
216 Self {
217 exposure: 0.0,
218 metadata: None,
219 }
220 }
221}
222
223#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
227pub struct PAEvent {
228 pub long: bool,
230 pub pole_height: Option<f64>,
232 pub strength: Option<f64>,
234}
235
236pub struct BacktestEngine {
247 config: BacktestConfig,
248}
249
250impl BacktestEngine {
251 pub fn new(config: BacktestConfig) -> Self {
252 Self { config }
253 }
254
255 pub fn with_default_costs() -> Self {
256 Self::new(BacktestConfig::default())
257 }
258
259 pub fn run(&self, lf: LazyFrame) -> Result<BacktestResult, BacktestError> {
263 let df = lf.collect()?;
264
265 if df.height() == 0 {
266 return Err(BacktestError::InvalidInput("empty dataframe".into()));
267 }
268
269 let ts_col = &self.config.timestamp_col;
271 let close_col = &self.config.close_col;
272 let sig_col = &self.config.signal_col;
273
274 for c in [ts_col, close_col, sig_col] {
275 if df.column(c).is_err() {
276 return Err(BacktestError::InvalidInput(format!("missing column: {}", c)));
277 }
278 }
279
280 let ts_series = df.column(ts_col)?.clone();
282 let close_ca = df.column(close_col)?.f64()?.clone();
283 let signal_series = df.column(sig_col)?;
284
285 let signal_vals: Vec<f64> = if signal_series.dtype().is_bool() {
288 signal_series
289 .bool()?
290 .into_iter()
291 .map(|b| if b.unwrap_or(false) { 1.0 } else { 0.0 })
292 .collect()
293 } else {
294 signal_series
295 .f64()?
296 .into_iter()
297 .map(|v| v.unwrap_or(0.0))
298 .collect()
299 };
300
301 let timestamps: Vec<DateTime<Utc>> = self.extract_timestamps(&ts_series)?;
303
304 let closes: Vec<f64> = close_ca.into_iter().map(|v| v.unwrap_or(f64::NAN)).collect();
305
306 if timestamps.len() != closes.len() || closes.len() != signal_vals.len() {
307 return Err(BacktestError::InvalidInput("column length mismatch".into()));
308 }
309
310 let cm = &self.config.cost_model;
313 let metas: Vec<Option<HashMap<String, f64>>> = vec![None; signal_vals.len()];
314 let (trades, equity_points) = run_simulation(
315 ×tamps,
316 &closes,
317 |i| (signal_vals[i], metas[i].clone()),
318 cm,
319 );
320
321 let trades_df = self.trades_to_df(&trades)?;
323 let equity_df = self.equity_to_df(&equity_points)?;
324
325 let final_equity = equity_points.last().map(|e| e.equity).unwrap_or(cm.initial_cash);
327 let total_return = (final_equity - cm.initial_cash) / cm.initial_cash;
328 let num_trades = trades.len() as f64;
329
330 let mut stats = HashMap::new();
331 stats.insert("initial_cash".to_string(), cm.initial_cash);
332 stats.insert("final_equity".to_string(), final_equity);
333 stats.insert("total_return".to_string(), total_return);
334 stats.insert("num_trades".to_string(), num_trades);
335 stats.insert("net_pnl".to_string(), final_equity - cm.initial_cash);
336
337 Ok(BacktestResult {
338 trades: trades_df,
339 equity_curve: equity_df,
340 stats,
341 })
342 }
343
344 fn extract_timestamps(&self, col: &Column) -> Result<Vec<DateTime<Utc>>, BacktestError> {
345 let s = col.as_series().ok_or_else(|| BacktestError::InvalidInput("column has no series backing".into()))?;
348
349 if let Ok(ca) = s.datetime() {
351 return Ok(ca
352 .into_iter()
353 .map(|opt| {
354 opt.map(|v| {
355 let secs = v / 1000;
357 let nanos = ((v % 1000) * 1_000_000) as u32;
358 DateTime::<Utc>::from_timestamp(secs, nanos).unwrap_or_else(Utc::now)
359 })
360 .unwrap_or_else(Utc::now)
361 })
362 .collect());
363 }
364
365 if let Ok(ca) = s.i64() {
366 return Ok(ca
368 .into_iter()
369 .enumerate()
370 .map(|(i, opt)| {
371 let v = opt.unwrap_or(i as i64);
372 DateTime::<Utc>::from_timestamp(v, 0).unwrap_or_else(Utc::now)
373 })
374 .collect());
375 }
376
377 Err(BacktestError::InvalidInput(
379 "timestamp column must be Datetime or Int64 for this MVP".into(),
380 ))
381 }
382
383 fn trades_to_df(&self, trades: &[Trade]) -> Result<DataFrame, PolarsError> {
384 if trades.is_empty() {
385 return Ok(DataFrame::new(vec![
387 Column::new("trade_id".into(), Vec::<u32>::new()),
388 Column::new("side".into(), Vec::<i8>::new()),
389 Column::new("entry_ts".into(), Vec::<i64>::new()),
390 Column::new("entry_price".into(), Vec::<f64>::new()),
391 Column::new("pnl_net".into(), Vec::<f64>::new()),
392 ])?);
393 }
394
395 let ids: Vec<u32> = trades.iter().map(|t| t.trade_id).collect();
396 let sides: Vec<i8> = trades.iter().map(|t| t.side).collect();
397 let entry_ts: Vec<i64> = trades.iter().map(|t| t.entry_ts.timestamp()).collect();
398 let entry_px: Vec<f64> = trades.iter().map(|t| t.entry_price).collect();
399 let exit_ts: Vec<Option<i64>> = trades
400 .iter()
401 .map(|t| t.exit_ts.map(|d| d.timestamp()))
402 .collect();
403 let pnl: Vec<f64> = trades.iter().map(|t| t.pnl_net).collect();
404
405 DataFrame::new(vec![
406 Column::new("trade_id".into(), ids),
407 Column::new("side".into(), sides),
408 Column::new("entry_ts".into(), entry_ts),
409 Column::new("entry_price".into(), entry_px),
410 Column::new("exit_ts".into(), exit_ts),
411 Column::new("pnl_net".into(), pnl),
412 ])
413 }
414
415 fn equity_to_df(&self, points: &[EquityPoint]) -> Result<DataFrame, PolarsError> {
416 if points.is_empty() {
417 return Ok(DataFrame::new(vec![
418 Column::new("ts".into(), Vec::<i64>::new()),
419 Column::new("equity".into(), Vec::<f64>::new()),
420 Column::new("position".into(), Vec::<f64>::new()),
421 ])?);
422 }
423
424 let ts: Vec<i64> = points.iter().map(|p| p.ts.timestamp()).collect();
425 let eq: Vec<f64> = points.iter().map(|p| p.equity).collect();
426 let pos: Vec<f64> = points.iter().map(|p| p.position).collect();
427 let cash: Vec<f64> = points.iter().map(|p| p.cash).collect();
428 let close: Vec<f64> = points.iter().map(|p| p.close).collect();
429
430 DataFrame::new(vec![
431 Column::new("ts".into(), ts),
432 Column::new("equity".into(), eq),
433 Column::new("cash".into(), cash),
434 Column::new("position".into(), pos),
435 Column::new("close".into(), close),
436 ])
437 }
438}
439
440pub fn backtest_simple_bool_signal(
443 ohlcv: DataFrame,
444 signal_col: &str,
445) -> Result<BacktestResult, BacktestError> {
446 let config = BacktestConfig {
447 signal_col: signal_col.to_string(),
448 ..Default::default()
449 };
450 let engine = BacktestEngine::new(config);
451 engine.run(ohlcv.lazy())
452}
453
454fn run_simulation(
462 timestamps: &[DateTime<Utc>],
463 closes: &[f64],
464 mut next_signal: impl FnMut(usize) -> (f64, Option<HashMap<String, f64>>),
465 cm: &CostModel,
466) -> (Vec<Trade>, Vec<EquityPoint>) {
467 let slip = cm.slippage_bps / 10000.0;
468 let comm = cm.commission_bps / 10000.0;
469
470 let mut cash = cm.initial_cash;
471 let mut current_exposure: f64 = 0.0;
472 let mut entry_price: f64 = 0.0;
473 let mut entry_ts: Option<DateTime<Utc>> = None;
474 let mut trade_id: u32 = 0;
475 let mut trades: Vec<Trade> = Vec::new();
476 let mut equity_points: Vec<EquityPoint> = Vec::with_capacity(closes.len());
477
478 for i in 0..closes.len() {
479 let close = closes[i];
480 if !close.is_finite() {
481 let equity = cash + current_exposure * close;
482 equity_points.push(EquityPoint {
483 ts: timestamps[i],
484 symbol: None,
485 equity,
486 cash,
487 position: current_exposure,
488 close,
489 });
490 continue;
491 }
492
493 let (desired_exposure, meta) = next_signal(i);
494 let desired = if desired_exposure > 0.0 { desired_exposure } else { 0.0 };
495
496 let currently_in = current_exposure > 0.0;
498
499 if desired > 0.0 && !currently_in {
500 let fill_price = close * (1.0 + slip);
502 let notional = fill_price * desired;
503 let cost = notional * comm;
504 cash -= notional + cost;
505 current_exposure = desired;
506 entry_price = fill_price;
507 entry_ts = Some(timestamps[i]);
508 trade_id += 1;
509 } else if desired == 0.0 && currently_in {
510 let fill_price = close * (1.0 - slip);
512 let notional = fill_price * current_exposure;
513 let cost = notional * comm;
514 let gross_pnl = (fill_price - entry_price) * current_exposure;
515 let net_pnl = gross_pnl - cost;
516 cash += notional - cost;
517
518 if let Some(ets) = entry_ts {
519 trades.push(Trade {
520 trade_id,
521 symbol: None,
522 side: 1,
523 entry_ts: ets,
524 entry_price,
525 entry_fill_price: entry_price,
526 exit_ts: Some(timestamps[i]),
527 exit_price: Some(close),
528 exit_fill_price: Some(fill_price),
529 pnl_gross: gross_pnl,
530 costs: cost,
531 pnl_net: net_pnl,
532 quantity: current_exposure,
533 entry_metadata: meta.clone(),
534 });
535 }
536 current_exposure = 0.0;
537 entry_price = 0.0;
538 entry_ts = None;
539 }
540
541 let equity = cash + current_exposure * close;
542 equity_points.push(EquityPoint {
543 ts: timestamps[i],
544 symbol: None,
545 equity,
546 cash,
547 position: current_exposure,
548 close,
549 });
550 }
551
552 if current_exposure > 0.0 {
554 let last_close = *closes.last().unwrap();
555 let gross = (last_close - entry_price) * current_exposure;
556 if let Some(ets) = entry_ts {
557 trades.push(Trade {
558 trade_id,
559 symbol: None,
560 side: 1,
561 entry_ts: ets,
562 entry_price,
563 entry_fill_price: entry_price,
564 exit_ts: None,
565 exit_price: Some(last_close),
566 exit_fill_price: None,
567 pnl_gross: gross,
568 costs: 0.0,
569 pnl_net: gross,
570 quantity: current_exposure,
571 entry_metadata: None, });
573 }
574 }
575
576 (trades, equity_points)
577}
578
579pub fn run_streaming_simulation<G>(
586 bars: &[Bar],
587 mut generator: G,
588 config: BacktestConfig,
589) -> Result<BacktestResult, BacktestError>
590where
591 G: for<'a> Next<&'a Bar, Output = StrategySignal>,
592{
593 if bars.is_empty() {
594 return Err(BacktestError::InvalidInput("empty bars".into()));
595 }
596
597 let timestamps: Vec<DateTime<Utc>> = bars.iter().map(|b| b.ts).collect();
598 let closes: Vec<f64> = bars.iter().map(|b| b.close).collect();
599
600 let cm = &config.cost_model;
601
602 let (trades, equity_points) = run_simulation(
603 ×tamps,
604 &closes,
605 |i| {
606 let sig = generator.next(&bars[i]);
607 (sig.exposure, sig.metadata.clone())
608 },
609 cm,
610 );
611
612 let trades_df = if trades.is_empty() {
617 DataFrame::new(vec![
618 Column::new("trade_id".into(), Vec::<u32>::new()),
619 Column::new("side".into(), Vec::<i8>::new()),
620 Column::new("entry_ts".into(), Vec::<i64>::new()),
621 Column::new("entry_price".into(), Vec::<f64>::new()),
622 Column::new("pnl_net".into(), Vec::<f64>::new()),
623 ])?
624 } else {
625 let ids: Vec<u32> = trades.iter().map(|t| t.trade_id).collect();
626 let sides: Vec<i8> = trades.iter().map(|t| t.side).collect();
627 let entry_ts: Vec<i64> = trades.iter().map(|t| t.entry_ts.timestamp()).collect();
628 let entry_px: Vec<f64> = trades.iter().map(|t| t.entry_price).collect();
629 let exit_ts: Vec<Option<i64>> = trades
630 .iter()
631 .map(|t| t.exit_ts.map(|d| d.timestamp()))
632 .collect();
633 let pnl: Vec<f64> = trades.iter().map(|t| t.pnl_net).collect();
634
635 DataFrame::new(vec![
636 Column::new("trade_id".into(), ids),
637 Column::new("side".into(), sides),
638 Column::new("entry_ts".into(), entry_ts),
639 Column::new("entry_price".into(), entry_px),
640 Column::new("exit_ts".into(), exit_ts),
641 Column::new("pnl_net".into(), pnl),
642 ])?
643 };
644
645 let equity_df = if equity_points.is_empty() {
646 DataFrame::new(vec![
647 Column::new("ts".into(), Vec::<i64>::new()),
648 Column::new("equity".into(), Vec::<f64>::new()),
649 Column::new("position".into(), Vec::<f64>::new()),
650 ])?
651 } else {
652 let ts: Vec<i64> = equity_points.iter().map(|p| p.ts.timestamp()).collect();
653 let eq: Vec<f64> = equity_points.iter().map(|p| p.equity).collect();
654 let pos: Vec<f64> = equity_points.iter().map(|p| p.position).collect();
655 let cash: Vec<f64> = equity_points.iter().map(|p| p.cash).collect();
656 let close: Vec<f64> = equity_points.iter().map(|p| p.close).collect();
657
658 DataFrame::new(vec![
659 Column::new("ts".into(), ts),
660 Column::new("equity".into(), eq),
661 Column::new("cash".into(), cash),
662 Column::new("position".into(), pos),
663 Column::new("close".into(), close),
664 ])?
665 };
666
667 let final_equity = equity_points.last().map(|e| e.equity).unwrap_or(cm.initial_cash);
668 let total_return = (final_equity - cm.initial_cash) / cm.initial_cash;
669 let num_trades = trades.len() as f64;
670
671 let mut stats = HashMap::new();
672 stats.insert("initial_cash".to_string(), cm.initial_cash);
673 stats.insert("final_equity".to_string(), final_equity);
674 stats.insert("total_return".to_string(), total_return);
675 stats.insert("num_trades".to_string(), num_trades);
676 stats.insert("net_pnl".to_string(), final_equity - cm.initial_cash);
677
678 Ok(BacktestResult {
679 trades: trades_df,
680 equity_curve: equity_df,
681 stats,
682 })
683}
684
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use polars::prelude::*;
    use quantwave_core::features::CyberCycleFeatureExtractor;
    use quantwave_core::regimes::tar::TAR;
    use quantwave_core::regimes::MarketRegime;
    use quantwave_core::traits::Next;
    use rand::Rng;
    use std::collections::HashMap;

    /// Reads a Float64 column of a result frame into a vector (nulls as 0.0).
    fn f64_column(df: &DataFrame, name: &str) -> Vec<f64> {
        df.column(name)
            .unwrap()
            .f64()
            .unwrap()
            .into_iter()
            .map(|v| v.unwrap_or(0.0))
            .collect()
    }

    #[test]
    fn test_basic_long_only_flip_on_synthetic() {
        let n: usize = 6;
        let timestamps: Vec<i64> = (0..n)
            .map(|i| 1_700_000_000i64 + (i as i64) * 3600)
            .collect();
        let closes = vec![100.0, 101.0, 102.5, 103.0, 102.0, 101.0];
        let signals = vec![0.0, 1.0, 1.0, 1.0, 0.0, 0.0];

        let df = DataFrame::new(vec![
            Column::new("timestamp".into(), timestamps),
            Column::new("close".into(), closes),
            Column::new("signal".into(), signals),
        ])
        .unwrap();

        let result = backtest_simple_bool_signal(df, "signal").expect("sim should succeed");

        // One round trip: enter on bar 1, exit on bar 4.
        assert_eq!(result.trades.height(), 1);
        let num_trades = result.stats["num_trades"];
        assert_relative_eq!(num_trades, 1.0, epsilon = 1e-9);

        let final_eq = result.stats["final_equity"];
        let init = 100_000.0;
        assert!(
            final_eq > init,
            "equity should grow on winning long: {} vs {}",
            final_eq,
            init
        );

        // One equity point per input bar; the last one matches the summary stat.
        assert_eq!(result.equity_curve.height(), n);
        let last_equity = f64_column(&result.equity_curve, "equity")[n - 1];
        assert_relative_eq!(last_equity, final_eq, epsilon = 1e-6);
    }

    #[test]
    fn test_flat_always_signal_produces_no_trades_and_flat_equity() {
        let n: usize = 5;
        let ts: Vec<i64> = (0..n).map(|i| 1_700_000_100 + i as i64).collect();
        let closes = vec![100.0; n];
        let signals = vec![0.0; n];

        let df = DataFrame::new(vec![
            Column::new("timestamp".into(), ts),
            Column::new("close".into(), closes),
            Column::new("signal".into(), signals),
        ])
        .unwrap();

        let result = backtest_simple_bool_signal(df, "signal").unwrap();

        assert_eq!(result.trades.height(), 0);
        assert_relative_eq!(result.stats["num_trades"], 0.0, epsilon = 1e-9);
        // Never trading means never paying costs: equity stays at the starting cash.
        assert_relative_eq!(result.stats["final_equity"], 100_000.0, epsilon = 1e-4);
    }

    #[test]
    fn test_synthetic_with_small_random_walk_and_bool_signal_matches_manual_calc() {
        let mut rng = rand::thread_rng();
        let n: usize = 8;
        let signals = vec![0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0];

        let mut price = 100.0_f64;
        let mut closes = Vec::with_capacity(n);
        let mut ts = Vec::with_capacity(n);
        for i in 0..n {
            ts.push(1_700_000_200 + i as i64);
            closes.push(price);
            price += rng.gen_range(-0.8..1.2);
        }

        let df = DataFrame::new(vec![
            Column::new("timestamp".into(), ts),
            Column::new("close".into(), closes.clone()),
            Column::new("signal".into(), signals.clone()),
        ])
        .unwrap();

        let result = backtest_simple_bool_signal(df, "signal").unwrap();

        // Independent re-implementation of the default cost model (2 bps slippage,
        // 5 bps commission, unit size).
        let slip = 0.0002;
        let comm = 0.0005;
        let init = 100_000.0;
        let mut cash = init;
        let mut pos = 0.0;
        let mut manual_equity = init;

        for (&c, &s) in closes.iter().zip(&signals) {
            let long = s > 0.0;
            if long && pos == 0.0 {
                let fill = c * (1.0 + slip);
                cash -= fill * (1.0 + comm);
                pos = 1.0;
            } else if !long && pos > 0.0 {
                let fill = c * (1.0 - slip);
                // Proceeds net of the exit commission, charged exactly once.
                cash += fill * (1.0 - comm);
                pos = 0.0;
            }
            manual_equity = cash + pos * c;
        }

        let engine_final = result.stats["final_equity"];
        assert_relative_eq!(engine_final, manual_equity, epsilon = 1e-6);
    }

    /// Tiny rolling-range detector standing in for a price-action "pole" pattern.
    #[derive(Debug, Clone)]
    struct SyntheticPoleHeightDetector {
        window: Vec<f64>,
        max_len: usize,
    }

    impl SyntheticPoleHeightDetector {
        fn new(max_len: usize) -> Self {
            Self {
                window: Vec::with_capacity(max_len),
                max_len,
            }
        }
    }

    #[derive(Debug, Clone, Copy)]
    struct PoleOutput {
        pole_height: f64,
        _strength: f64,
    }

    impl Next<f64> for SyntheticPoleHeightDetector {
        type Output = PoleOutput;

        fn next(&mut self, price: f64) -> PoleOutput {
            self.window.push(price);
            if self.window.len() > self.max_len {
                self.window.remove(0);
            }
            let h = if self.window.len() >= 3 {
                let mn = self.window.iter().fold(f64::INFINITY, |a, &b| a.min(b));
                let mx = self.window.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                (mx - mn).max(0.1)
            } else {
                1.0
            };
            PoleOutput {
                pole_height: h,
                _strength: (h / 8.0).clamp(0.3, 1.0),
            }
        }
    }

    /// Strategy combining a regime filter, a cycle feature and pole-height sizing.
    #[derive(Debug, Clone)]
    struct RegimeFeaturePAStrategy {
        regime: TAR,
        cycle: CyberCycleFeatureExtractor,
        pa: SyntheticPoleHeightDetector,
        feat_thresh: f64,
    }

    impl RegimeFeaturePAStrategy {
        fn new() -> Self {
            Self {
                regime: TAR::new(105.0),
                cycle: CyberCycleFeatureExtractor::new(14),
                pa: SyntheticPoleHeightDetector::new(6),
                feat_thresh: 0.02,
            }
        }
    }

    impl Next<&Bar> for RegimeFeaturePAStrategy {
        type Output = StrategySignal;

        fn next(&mut self, bar: &Bar) -> StrategySignal {
            let regime = self.regime.next(bar.close);
            let feat = self.cycle.next(bar.close);
            let pa = self.pa.next(bar.close);

            let regime_ok = matches!(
                regime,
                MarketRegime::Steady | MarketRegime::Cluster(_) | MarketRegime::Bull
            );
            let feat_ok = feat.cycle_momentum.abs() > self.feat_thresh;

            let exposure = if regime_ok && feat_ok {
                (pa.pole_height / 4.0).clamp(0.4, 2.2)
            } else {
                0.0
            };

            let mut meta = HashMap::new();
            meta.insert("pole_height".to_string(), pa.pole_height);
            meta.insert("cycle_momentum".to_string(), feat.cycle_momentum);
            meta.insert("regime_ok".to_string(), if regime_ok { 1.0 } else { 0.0 });

            StrategySignal {
                exposure,
                metadata: Some(meta),
            }
        }
    }

    #[test]
    fn test_batch_vs_streaming_parity_regime_feature_rich_pa_pole_sizing() {
        let n: usize = 120;
        let mut timestamps = Vec::with_capacity(n);
        let mut closes = Vec::with_capacity(n);
        for i in 0..n {
            let secs = 1_700_000_500i64 + (i as i64) * 3600;
            timestamps.push(chrono::DateTime::<chrono::Utc>::from_timestamp(secs, 0).unwrap());
            // Slow sine wave on a gentle upward drift.
            let wave = (i as f64 * 0.18).sin() * 4.5;
            closes.push(101.5 + wave + (i as f64 * 0.008));
        }

        let bars: Vec<Bar> = timestamps
            .iter()
            .zip(&closes)
            .map(|(&ts, &close)| Bar { ts, close })
            .collect();

        // Batch path: precompute exposures with one strategy instance.
        let mut batch_gen = RegimeFeaturePAStrategy::new();
        let exposures: Vec<f64> = bars.iter().map(|bar| batch_gen.next(bar).exposure).collect();

        let df = DataFrame::new(vec![
            Column::new(
                "timestamp".into(),
                timestamps.iter().map(|t| t.timestamp()).collect::<Vec<_>>(),
            ),
            Column::new("close".into(), closes),
            Column::new("signal".into(), exposures),
        ])
        .unwrap();

        let batch_res = backtest_simple_bool_signal(df, "signal").expect("batch parity run");

        // Streaming path: a fresh instance sees the same bars one at a time.
        let stream_res = run_streaming_simulation(
            &bars,
            RegimeFeaturePAStrategy::new(),
            BacktestConfig::default(),
        )
        .expect("streaming parity run");

        let b_eq = f64_column(&batch_res.equity_curve, "equity");
        let s_eq = f64_column(&stream_res.equity_curve, "equity");
        assert_eq!(b_eq.len(), s_eq.len(), "equity curve lengths must match");
        for (i, (b, s)) in b_eq.iter().zip(&s_eq).enumerate() {
            // Absolute-or-relative tolerance of 1e-8.
            let tol = 1e-8_f64.max(1e-8 * b.abs().max(s.abs()));
            assert!(
                (b - s).abs() <= tol,
                "equity diverged at bar {}: {} vs {}",
                i,
                b,
                s
            );
        }

        for k in ["final_equity", "net_pnl", "num_trades"] {
            let bv = batch_res.stats[k];
            let sv = stream_res.stats[k];
            assert_relative_eq!(bv, sv, epsilon = 1e-6, max_relative = 1e-6);
        }

        assert_eq!(
            batch_res.trades.height(),
            stream_res.trades.height(),
            "trade counts must match exactly for parity"
        );
        assert!(
            batch_res.trades.height() >= 1,
            "parity test strategy must generate >=1 trade on synthetic data"
        );
    }
}
1045
#[cfg(test)]
mod integration_example_between_epics {
    use super::*;
    use polars::prelude::*;
    use quantwave_core::features::HurstFeatureExtractor;

    /// Smoke test: an ML feature (Hurst persistence) drives the exposure column
    /// of a batch backtest end to end.
    #[test]
    fn ml_features_feed_backtester_with_metadata() {
        let n: usize = 60;
        let closes: Vec<f64> = (0..n).map(|i| 100.0 + i as f64 * 0.25).collect();
        let timestamps: Vec<i64> = (0..n).map(|i| 1_700_000_000i64 + i as i64).collect();

        // Go long whenever the series looks persistent (trending).
        let mut hurst = HurstFeatureExtractor::new(15);
        let exposures: Vec<f64> = closes
            .iter()
            .map(|&c| if hurst.next(c).persistence > 0.52 { 1.0 } else { 0.0 })
            .collect();

        let lf = df![
            "timestamp" => timestamps,
            "close" => closes,
            "exposure" => exposures,
        ]
        .unwrap()
        .lazy();

        let config = BacktestConfig {
            signal_col: "exposure".to_string(),
            ..Default::default()
        };

        let result = BacktestEngine::new(config).run(lf).unwrap();

        println!(
            "Integration smoke test: {} trades produced using ML feature (Hurst) driven exposure",
            result.trades.height()
        );
        assert_eq!(result.equity_curve.height(), n);
    }
}
1101}