1use std::collections::HashMap;
4
5use crate::indicators::{self, Indicator};
6use crate::models::chart::Candle;
7
8use super::config::BacktestConfig;
9use super::error::{BacktestError, Result};
10use super::position::{Position, PositionSide, Trade};
11use super::result::{BacktestResult, EquityPoint, PerformanceMetrics, SignalRecord};
12use super::signal::{Signal, SignalDirection};
13use super::strategy::{Strategy, StrategyContext};
14
15pub struct BacktestEngine {
19 config: BacktestConfig,
20}
21
22impl BacktestEngine {
23 pub fn new(config: BacktestConfig) -> Self {
25 Self { config }
26 }
27
28 pub fn run<S: Strategy>(
30 &self,
31 symbol: &str,
32 candles: &[Candle],
33 strategy: S,
34 ) -> Result<BacktestResult> {
35 let warmup = strategy.warmup_period();
36 if candles.len() < warmup {
37 return Err(BacktestError::insufficient_data(warmup, candles.len()));
38 }
39
40 let indicators = self.compute_indicators(candles, &strategy)?;
42
43 let mut equity = self.config.initial_capital;
45 let mut cash = self.config.initial_capital;
46 let mut position: Option<Position> = None;
47 let mut trades: Vec<Trade> = Vec::new();
48 let mut equity_curve: Vec<EquityPoint> = Vec::new();
49 let mut signals: Vec<SignalRecord> = Vec::new();
50 let mut peak_equity = equity;
51
52 for i in 0..candles.len() {
54 let candle = &candles[i];
55
56 if let Some(ref pos) = position {
58 let pos_value = pos.current_value(candle.close);
59 equity = cash + pos_value;
60 } else {
61 equity = cash;
62 }
63
64 if equity > peak_equity {
66 peak_equity = equity;
67 }
68 let drawdown_pct = if peak_equity > 0.0 {
69 (peak_equity - equity) / peak_equity
70 } else {
71 0.0
72 };
73
74 equity_curve.push(EquityPoint {
75 timestamp: candle.timestamp,
76 equity,
77 drawdown_pct,
78 });
79
80 if let Some(ref pos) = position
82 && let Some(exit_signal) = self.check_sl_tp(pos, candle)
83 {
84 let exit_price = self.config.apply_exit_slippage(candle.close, pos.is_long());
85 let exit_commission = self.config.calculate_commission(exit_price * pos.quantity);
86
87 signals.push(SignalRecord {
88 timestamp: candle.timestamp,
89 price: candle.close,
90 direction: SignalDirection::Exit,
91 strength: 1.0,
92 reason: exit_signal.reason.clone(),
93 executed: true,
94 });
95
96 let trade = position.take().unwrap().close(
97 candle.timestamp,
98 exit_price,
99 exit_commission,
100 exit_signal,
101 );
102
103 cash += trade.entry_value() + trade.pnl;
104 trades.push(trade);
105 continue; }
107
108 if i < warmup.saturating_sub(1) {
110 continue;
111 }
112
113 let ctx = StrategyContext {
115 candles: &candles[..=i],
116 index: i,
117 position: position.as_ref(),
118 equity,
119 indicators: &indicators,
120 };
121
122 let signal = strategy.on_candle(&ctx);
124
125 if signal.is_hold() {
127 continue;
128 }
129
130 if signal.strength.value() < self.config.min_signal_strength {
132 signals.push(SignalRecord {
133 timestamp: signal.timestamp,
134 price: signal.price,
135 direction: signal.direction,
136 strength: signal.strength.value(),
137 reason: signal.reason.clone(),
138 executed: false,
139 });
140 continue;
141 }
142
143 let executed =
145 self.execute_signal(&signal, candle, &mut position, &mut cash, &mut trades);
146
147 signals.push(SignalRecord {
148 timestamp: signal.timestamp,
149 price: signal.price,
150 direction: signal.direction,
151 strength: signal.strength.value(),
152 reason: signal.reason,
153 executed,
154 });
155 }
156
157 if self.config.close_at_end
159 && let Some(pos) = position.take()
160 {
161 let last_candle = candles.last().unwrap();
162 let exit_price = self
163 .config
164 .apply_exit_slippage(last_candle.close, pos.is_long());
165 let exit_commission = self.config.calculate_commission(exit_price * pos.quantity);
166
167 let exit_signal = Signal::exit(last_candle.timestamp, last_candle.close)
168 .with_reason("End of backtest");
169
170 let trade = pos.close(
171 last_candle.timestamp,
172 exit_price,
173 exit_commission,
174 exit_signal,
175 );
176 cash += trade.entry_value() + trade.pnl;
177 trades.push(trade);
178 }
179
180 let final_equity = if let Some(ref pos) = position {
182 cash + pos.current_value(candles.last().unwrap().close)
183 } else {
184 cash
185 };
186
187 let executed_signals = signals.iter().filter(|s| s.executed).count();
189 let metrics = PerformanceMetrics::calculate(
190 &trades,
191 &equity_curve,
192 self.config.initial_capital,
193 signals.len(),
194 executed_signals,
195 );
196
197 let start_timestamp = candles.first().map(|c| c.timestamp).unwrap_or(0);
198 let end_timestamp = candles.last().map(|c| c.timestamp).unwrap_or(0);
199
200 Ok(BacktestResult {
201 symbol: symbol.to_string(),
202 strategy_name: strategy.name().to_string(),
203 config: self.config.clone(),
204 start_timestamp,
205 end_timestamp,
206 initial_capital: self.config.initial_capital,
207 final_equity,
208 metrics,
209 trades,
210 equity_curve,
211 signals,
212 open_position: position,
213 })
214 }
215
216 fn compute_indicators<S: Strategy>(
218 &self,
219 candles: &[Candle],
220 strategy: &S,
221 ) -> Result<HashMap<String, Vec<Option<f64>>>> {
222 let mut result = HashMap::new();
223
224 let closes: Vec<f64> = candles.iter().map(|c| c.close).collect();
225 let highs: Vec<f64> = candles.iter().map(|c| c.high).collect();
226 let lows: Vec<f64> = candles.iter().map(|c| c.low).collect();
227 let volumes: Vec<f64> = candles.iter().map(|c| c.volume as f64).collect();
228
229 for (name, indicator) in strategy.required_indicators() {
230 match indicator {
231 Indicator::Sma(period) => {
232 let values = indicators::sma(&closes, period);
233 result.insert(name, values);
234 }
235 Indicator::Ema(period) => {
236 let values = indicators::ema(&closes, period);
237 result.insert(name, values);
238 }
239 Indicator::Rsi(period) => {
240 let values = indicators::rsi(&closes, period)?;
241 result.insert(name, values);
242 }
243 Indicator::Macd { fast, slow, signal } => {
244 let macd_result = indicators::macd(&closes, fast, slow, signal)?;
245 result.insert("macd_line".to_string(), macd_result.macd_line);
246 result.insert("macd_signal".to_string(), macd_result.signal_line);
247 result.insert("macd_histogram".to_string(), macd_result.histogram);
248 }
249 Indicator::Bollinger { period, std_dev } => {
250 let bb = indicators::bollinger_bands(&closes, period, std_dev)?;
251 result.insert("bollinger_upper".to_string(), bb.upper);
252 result.insert("bollinger_middle".to_string(), bb.middle);
253 result.insert("bollinger_lower".to_string(), bb.lower);
254 }
255 Indicator::Atr(period) => {
256 let values = indicators::atr(&highs, &lows, &closes, period)?;
257 result.insert(name, values);
258 }
259 Indicator::Supertrend { period, multiplier } => {
260 let st = indicators::supertrend(&highs, &lows, &closes, period, multiplier)?;
261 result.insert("supertrend_value".to_string(), st.value);
262 let uptrend: Vec<Option<f64>> = st
264 .is_uptrend
265 .into_iter()
266 .map(|v| v.map(|b| if b { 1.0 } else { 0.0 }))
267 .collect();
268 result.insert("supertrend_uptrend".to_string(), uptrend);
269 }
270 Indicator::DonchianChannels(period) => {
271 let dc = indicators::donchian_channels(&highs, &lows, period)?;
272 result.insert("donchian_upper".to_string(), dc.upper);
273 result.insert("donchian_middle".to_string(), dc.middle);
274 result.insert("donchian_lower".to_string(), dc.lower);
275 }
276 Indicator::Wma(period) => {
277 let values = indicators::wma(&closes, period)?;
278 result.insert(name, values);
279 }
280 Indicator::Dema(period) => {
281 let values = indicators::dema(&closes, period)?;
282 result.insert(name, values);
283 }
284 Indicator::Tema(period) => {
285 let values = indicators::tema(&closes, period)?;
286 result.insert(name, values);
287 }
288 Indicator::Hma(period) => {
289 let values = indicators::hma(&closes, period)?;
290 result.insert(name, values);
291 }
292 Indicator::Obv => {
293 let values = indicators::obv(&closes, &volumes)?;
294 result.insert(name, values);
295 }
296 Indicator::Momentum(period) => {
297 let values = indicators::momentum(&closes, period)?;
298 result.insert(name, values);
299 }
300 Indicator::Roc(period) => {
301 let values = indicators::roc(&closes, period)?;
302 result.insert(name, values);
303 }
304 Indicator::Cci(period) => {
305 let values = indicators::cci(&highs, &lows, &closes, period)?;
306 result.insert(name, values);
307 }
308 Indicator::WilliamsR(period) => {
309 let values = indicators::williams_r(&highs, &lows, &closes, period)?;
310 result.insert(name, values);
311 }
312 Indicator::Adx(period) => {
313 let values = indicators::adx(&highs, &lows, &closes, period)?;
314 result.insert(name, values);
315 }
316 Indicator::Mfi(period) => {
317 let values = indicators::mfi(&highs, &lows, &closes, &volumes, period)?;
318 result.insert(name, values);
319 }
320 Indicator::Cmf(period) => {
321 let values = indicators::cmf(&highs, &lows, &closes, &volumes, period)?;
322 result.insert(name, values);
323 }
324 Indicator::Cmo(period) => {
325 let values = indicators::cmo(&closes, period)?;
326 result.insert(name, values);
327 }
328 Indicator::Vwma(period) => {
329 let values = indicators::vwma(&closes, &volumes, period)?;
330 result.insert(name, values);
331 }
332 Indicator::Alma {
333 period,
334 offset,
335 sigma,
336 } => {
337 let values = indicators::alma(&closes, period, offset, sigma)?;
338 result.insert(name, values);
339 }
340 Indicator::McginleyDynamic(period) => {
341 let values = indicators::mcginley_dynamic(&closes, period)?;
342 result.insert(name, values);
343 }
344 Indicator::Stochastic {
346 k_period,
347 k_slow: _,
348 d_period,
349 } => {
350 let stoch = indicators::stochastic(&highs, &lows, &closes, k_period, d_period)?;
351 result.insert("stochastic_k".to_string(), stoch.k);
352 result.insert("stochastic_d".to_string(), stoch.d);
353 }
354 Indicator::StochasticRsi {
355 rsi_period,
356 stoch_period,
357 k_period: _,
358 d_period: _,
359 } => {
360 let values = indicators::stochastic_rsi(&closes, rsi_period, stoch_period)?;
361 result.insert(name, values);
362 }
363 Indicator::AwesomeOscillator { fast: _, slow: _ } => {
364 let values = indicators::awesome_oscillator(&highs, &lows)?;
366 result.insert(name, values);
367 }
368 Indicator::CoppockCurve {
369 wma_period: _,
370 long_roc: _,
371 short_roc: _,
372 } => {
373 let values = indicators::coppock_curve(&closes)?;
375 result.insert(name, values);
376 }
377 Indicator::Aroon(period) => {
379 let aroon_result = indicators::aroon(&highs, &lows, period)?;
380 result.insert("aroon_up".to_string(), aroon_result.aroon_up);
381 result.insert("aroon_down".to_string(), aroon_result.aroon_down);
382 }
383 Indicator::Ichimoku {
384 conversion: _,
385 base: _,
386 lagging: _,
387 displacement: _,
388 } => {
389 let ich = indicators::ichimoku(&highs, &lows, &closes)?;
391 result.insert("ichimoku_conversion".to_string(), ich.conversion_line);
392 result.insert("ichimoku_base".to_string(), ich.base_line);
393 result.insert("ichimoku_leading_a".to_string(), ich.leading_span_a);
394 result.insert("ichimoku_leading_b".to_string(), ich.leading_span_b);
395 result.insert("ichimoku_lagging".to_string(), ich.lagging_span);
396 }
397 Indicator::ParabolicSar { step, max } => {
398 let values = indicators::parabolic_sar(&highs, &lows, &closes, step, max)?;
399 result.insert(name, values);
400 }
401 Indicator::KeltnerChannels {
403 period,
404 multiplier,
405 atr_period,
406 } => {
407 let kc = indicators::keltner_channels(
408 &highs, &lows, &closes, period, atr_period, multiplier,
409 )?;
410 result.insert("keltner_upper".to_string(), kc.upper);
411 result.insert("keltner_middle".to_string(), kc.middle);
412 result.insert("keltner_lower".to_string(), kc.lower);
413 }
414 Indicator::TrueRange => {
415 let values = indicators::true_range(&highs, &lows, &closes)?;
416 result.insert(name, values);
417 }
418 Indicator::ChoppinessIndex(period) => {
419 let values = indicators::choppiness_index(&highs, &lows, &closes, period)?;
420 result.insert(name, values);
421 }
422 Indicator::Vwap => {
424 let values = indicators::vwap(&highs, &lows, &closes, &volumes)?;
425 result.insert(name, values);
426 }
427 Indicator::ChaikinOscillator => {
428 let values = indicators::chaikin_oscillator(&highs, &lows, &closes, &volumes)?;
429 result.insert(name, values);
430 }
431 Indicator::AccumulationDistribution => {
432 let values =
433 indicators::accumulation_distribution(&highs, &lows, &closes, &volumes)?;
434 result.insert(name, values);
435 }
436 Indicator::BalanceOfPower(period) => {
437 let opens: Vec<f64> = candles.iter().map(|c| c.open).collect();
438 let values =
439 indicators::balance_of_power(&opens, &highs, &lows, &closes, period)?;
440 result.insert(name, values);
441 }
442 Indicator::BullBearPower(_period) => {
444 let bbp = indicators::bull_bear_power(&highs, &lows, &closes)?;
446 result.insert("bull_power".to_string(), bbp.bull_power);
447 result.insert("bear_power".to_string(), bbp.bear_power);
448 }
449 Indicator::ElderRay(_period) => {
450 let er = indicators::elder_ray(&highs, &lows, &closes)?;
452 result.insert("elder_bull".to_string(), er.bull_power);
453 result.insert("elder_bear".to_string(), er.bear_power);
454 }
455 }
456 }
457
458 Ok(result)
459 }
460
461 fn check_sl_tp(&self, position: &Position, candle: &Candle) -> Option<Signal> {
463 let return_pct = position.unrealized_return_pct(candle.close) / 100.0;
464
465 if let Some(sl_pct) = self.config.stop_loss_pct
467 && return_pct <= -sl_pct
468 {
469 return Some(
470 Signal::exit(candle.timestamp, candle.close)
471 .with_reason(format!("Stop-loss triggered ({:.1}%)", return_pct * 100.0)),
472 );
473 }
474
475 if let Some(tp_pct) = self.config.take_profit_pct
477 && return_pct >= tp_pct
478 {
479 return Some(
480 Signal::exit(candle.timestamp, candle.close).with_reason(format!(
481 "Take-profit triggered ({:.1}%)",
482 return_pct * 100.0
483 )),
484 );
485 }
486
487 None
488 }
489
490 fn execute_signal(
492 &self,
493 signal: &Signal,
494 candle: &Candle,
495 position: &mut Option<Position>,
496 cash: &mut f64,
497 trades: &mut Vec<Trade>,
498 ) -> bool {
499 match signal.direction {
500 SignalDirection::Long => {
501 if position.is_some() {
502 return false; }
504 self.open_position(position, cash, candle, signal, true)
505 }
506 SignalDirection::Short => {
507 if position.is_some() {
508 return false; }
510 if !self.config.allow_short {
511 return false; }
513 self.open_position(position, cash, candle, signal, false)
514 }
515 SignalDirection::Exit => {
516 if position.is_none() {
517 return false; }
519 self.close_position(position, cash, trades, candle, signal)
520 }
521 SignalDirection::Hold => false,
522 }
523 }
524
525 fn open_position(
527 &self,
528 position: &mut Option<Position>,
529 cash: &mut f64,
530 candle: &Candle,
531 signal: &Signal,
532 is_long: bool,
533 ) -> bool {
534 let entry_price = self.config.apply_entry_slippage(candle.close, is_long);
535 let quantity = self.config.calculate_position_size(*cash, entry_price);
536
537 if quantity <= 0.0 {
538 return false; }
540
541 let entry_value = entry_price * quantity;
542 let commission = self.config.calculate_commission(entry_value);
543
544 if entry_value + commission > *cash {
545 return false; }
547
548 let side = if is_long {
549 PositionSide::Long
550 } else {
551 PositionSide::Short
552 };
553
554 *cash -= entry_value + commission;
555 *position = Some(Position::new(
556 side,
557 candle.timestamp,
558 entry_price,
559 quantity,
560 commission,
561 signal.clone(),
562 ));
563
564 true
565 }
566
567 fn close_position(
569 &self,
570 position: &mut Option<Position>,
571 cash: &mut f64,
572 trades: &mut Vec<Trade>,
573 candle: &Candle,
574 signal: &Signal,
575 ) -> bool {
576 let pos = match position.take() {
577 Some(p) => p,
578 None => return false,
579 };
580
581 let exit_price = self.config.apply_exit_slippage(candle.close, pos.is_long());
582 let exit_commission = self.config.calculate_commission(exit_price * pos.quantity);
583
584 let trade = pos.close(
585 candle.timestamp,
586 exit_price,
587 exit_commission,
588 signal.clone(),
589 );
590
591 *cash += trade.entry_value() + trade.pnl;
592 trades.push(trade);
593
594 true
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::strategy::SmaCrossover;

    /// Builds candles with open == close == price, a ±1% high/low range, a
    /// constant volume, and the index as the timestamp.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1000,
                adj_close: Some(p),
            })
            .collect()
    }

    /// Returns true if the record has a reason containing `needle`.
    fn reason_contains(record: &SignalRecord, needle: &str) -> bool {
        record.reason.as_ref().is_some_and(|r| r.contains(needle))
    }

    #[test]
    fn test_engine_basic() {
        let mut prices = vec![100.0; 30];
        // Rally from index 15 to 24, then sell off.
        for (i, price) in prices.iter_mut().enumerate().take(25).skip(15) {
            *price = 100.0 + (i - 15) as f64 * 2.0;
        }
        for (i, price) in prices.iter_mut().enumerate().take(30).skip(25) {
            *price = 118.0 - (i - 25) as f64 * 3.0;
        }

        let candles = make_candles(&prices);
        let config = BacktestConfig::builder()
            .initial_capital(10_000.0)
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let engine = BacktestEngine::new(config);
        let strategy = SmaCrossover::new(5, 10);
        let result = engine.run("TEST", &candles, strategy).unwrap();

        assert_eq!(result.symbol, "TEST");
        assert_eq!(result.strategy_name, "SMA Crossover");
        // The engine records exactly one equity point per candle.
        assert_eq!(result.equity_curve.len(), candles.len());
    }

    #[test]
    fn test_stop_loss() {
        let mut prices = vec![100.0; 20];
        // Gentle rise to trigger an entry, then a steep fall through the stop.
        for (i, price) in prices.iter_mut().enumerate().take(15).skip(10) {
            *price = 100.0 + (i - 10) as f64 * 2.0;
        }
        for (i, price) in prices.iter_mut().enumerate().take(20).skip(15) {
            *price = 108.0 - (i - 15) as f64 * 10.0;
        }

        let candles = make_candles(&prices);
        let config = BacktestConfig::builder()
            .initial_capital(10_000.0)
            .stop_loss_pct(0.05)
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let engine = BacktestEngine::new(config);
        let strategy = SmaCrossover::new(3, 6);
        let result = engine.run("TEST", &candles, strategy).unwrap();

        assert_eq!(result.equity_curve.len(), candles.len());

        // Any stop-loss the engine fires must be recorded as an executed exit.
        for record in result
            .signals
            .iter()
            .filter(|s| reason_contains(s, "Stop-loss"))
        {
            assert!(record.executed);
            assert!(matches!(record.direction, SignalDirection::Exit));
        }
    }

    #[test]
    fn test_check_sl_tp_stop_loss_threshold() {
        let config = BacktestConfig::builder()
            .initial_capital(10_000.0)
            .stop_loss_pct(0.05)
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();
        let engine = BacktestEngine::new(config);
        let position = Position::new(
            PositionSide::Long,
            0,
            100.0,
            10.0,
            0.0,
            Signal::exit(0, 100.0),
        );

        let candles = make_candles(&[98.0, 90.0]);

        // A 2% dip stays inside the 5% stop.
        assert!(engine.check_sl_tp(&position, &candles[0]).is_none());

        // A 10% drop must trigger the stop-loss.
        let exit = engine
            .check_sl_tp(&position, &candles[1])
            .expect("5% stop-loss should fire on a 10% drop");
        assert!(exit.reason.as_deref().is_some_and(|r| r.contains("Stop-loss")));
    }

    #[test]
    fn test_insufficient_data() {
        let candles = make_candles(&[100.0, 101.0, 102.0]);
        let config = BacktestConfig::default();
        let engine = BacktestEngine::new(config);
        let strategy = SmaCrossover::new(10, 20);
        let result = engine.run("TEST", &candles, strategy);
        assert!(result.is_err());
    }
}