finance_query/backtesting/engine.rs

1//! Backtest execution engine.
2
3use std::collections::HashMap;
4
5use crate::indicators::{self, Indicator};
6use crate::models::chart::Candle;
7
8use super::config::BacktestConfig;
9use super::error::{BacktestError, Result};
10use super::position::{Position, PositionSide, Trade};
11use super::result::{BacktestResult, EquityPoint, PerformanceMetrics, SignalRecord};
12use super::signal::{Signal, SignalDirection};
13use super::strategy::{Strategy, StrategyContext};
14
15/// Backtest execution engine.
16///
17/// Handles indicator pre-computation, position management, and trade execution.
18pub struct BacktestEngine {
19    config: BacktestConfig,
20}
21
22impl BacktestEngine {
23    /// Create a new backtest engine with the given configuration
24    pub fn new(config: BacktestConfig) -> Self {
25        Self { config }
26    }
27
28    /// Run a backtest with the given strategy on historical candle data
29    pub fn run<S: Strategy>(
30        &self,
31        symbol: &str,
32        candles: &[Candle],
33        strategy: S,
34    ) -> Result<BacktestResult> {
35        let warmup = strategy.warmup_period();
36        if candles.len() < warmup {
37            return Err(BacktestError::insufficient_data(warmup, candles.len()));
38        }
39
40        // Pre-compute all required indicators
41        let indicators = self.compute_indicators(candles, &strategy)?;
42
43        // Initialize state
44        let mut equity = self.config.initial_capital;
45        let mut cash = self.config.initial_capital;
46        let mut position: Option<Position> = None;
47        let mut trades: Vec<Trade> = Vec::new();
48        let mut equity_curve: Vec<EquityPoint> = Vec::new();
49        let mut signals: Vec<SignalRecord> = Vec::new();
50        let mut peak_equity = equity;
51
52        // Main simulation loop
53        for i in 0..candles.len() {
54            let candle = &candles[i];
55
56            // Update equity with current position value
57            if let Some(ref pos) = position {
58                let pos_value = pos.current_value(candle.close);
59                equity = cash + pos_value;
60            } else {
61                equity = cash;
62            }
63
64            // Track drawdown
65            if equity > peak_equity {
66                peak_equity = equity;
67            }
68            let drawdown_pct = if peak_equity > 0.0 {
69                (peak_equity - equity) / peak_equity
70            } else {
71                0.0
72            };
73
74            equity_curve.push(EquityPoint {
75                timestamp: candle.timestamp,
76                equity,
77                drawdown_pct,
78            });
79
80            // Check stop-loss / take-profit on existing position
81            if let Some(ref pos) = position
82                && let Some(exit_signal) = self.check_sl_tp(pos, candle)
83            {
84                let exit_price = self.config.apply_exit_slippage(candle.close, pos.is_long());
85                let exit_commission = self.config.calculate_commission(exit_price * pos.quantity);
86
87                signals.push(SignalRecord {
88                    timestamp: candle.timestamp,
89                    price: candle.close,
90                    direction: SignalDirection::Exit,
91                    strength: 1.0,
92                    reason: exit_signal.reason.clone(),
93                    executed: true,
94                });
95
96                let trade = position.take().unwrap().close(
97                    candle.timestamp,
98                    exit_price,
99                    exit_commission,
100                    exit_signal,
101                );
102
103                cash += trade.entry_value() + trade.pnl;
104                trades.push(trade);
105                continue; // Skip strategy signal this bar
106            }
107
108            // Skip strategy signals during warmup period
109            if i < warmup.saturating_sub(1) {
110                continue;
111            }
112
113            // Build strategy context
114            let ctx = StrategyContext {
115                candles: &candles[..=i],
116                index: i,
117                position: position.as_ref(),
118                equity,
119                indicators: &indicators,
120            };
121
122            // Get strategy signal
123            let signal = strategy.on_candle(&ctx);
124
125            // Skip hold signals
126            if signal.is_hold() {
127                continue;
128            }
129
130            // Check signal strength threshold
131            if signal.strength.value() < self.config.min_signal_strength {
132                signals.push(SignalRecord {
133                    timestamp: signal.timestamp,
134                    price: signal.price,
135                    direction: signal.direction,
136                    strength: signal.strength.value(),
137                    reason: signal.reason.clone(),
138                    executed: false,
139                });
140                continue;
141            }
142
143            // Record the signal
144            let executed =
145                self.execute_signal(&signal, candle, &mut position, &mut cash, &mut trades);
146
147            signals.push(SignalRecord {
148                timestamp: signal.timestamp,
149                price: signal.price,
150                direction: signal.direction,
151                strength: signal.strength.value(),
152                reason: signal.reason,
153                executed,
154            });
155        }
156
157        // Close any open position at end if configured
158        if self.config.close_at_end
159            && let Some(pos) = position.take()
160        {
161            let last_candle = candles.last().unwrap();
162            let exit_price = self
163                .config
164                .apply_exit_slippage(last_candle.close, pos.is_long());
165            let exit_commission = self.config.calculate_commission(exit_price * pos.quantity);
166
167            let exit_signal = Signal::exit(last_candle.timestamp, last_candle.close)
168                .with_reason("End of backtest");
169
170            let trade = pos.close(
171                last_candle.timestamp,
172                exit_price,
173                exit_commission,
174                exit_signal,
175            );
176            cash += trade.entry_value() + trade.pnl;
177            trades.push(trade);
178        }
179
180        // Final equity
181        let final_equity = if let Some(ref pos) = position {
182            cash + pos.current_value(candles.last().unwrap().close)
183        } else {
184            cash
185        };
186
187        // Calculate metrics
188        let executed_signals = signals.iter().filter(|s| s.executed).count();
189        let metrics = PerformanceMetrics::calculate(
190            &trades,
191            &equity_curve,
192            self.config.initial_capital,
193            signals.len(),
194            executed_signals,
195        );
196
197        let start_timestamp = candles.first().map(|c| c.timestamp).unwrap_or(0);
198        let end_timestamp = candles.last().map(|c| c.timestamp).unwrap_or(0);
199
200        Ok(BacktestResult {
201            symbol: symbol.to_string(),
202            strategy_name: strategy.name().to_string(),
203            config: self.config.clone(),
204            start_timestamp,
205            end_timestamp,
206            initial_capital: self.config.initial_capital,
207            final_equity,
208            metrics,
209            trades,
210            equity_curve,
211            signals,
212            open_position: position,
213        })
214    }
215
216    /// Pre-compute all indicators required by the strategy
217    fn compute_indicators<S: Strategy>(
218        &self,
219        candles: &[Candle],
220        strategy: &S,
221    ) -> Result<HashMap<String, Vec<Option<f64>>>> {
222        let mut result = HashMap::new();
223
224        let closes: Vec<f64> = candles.iter().map(|c| c.close).collect();
225        let highs: Vec<f64> = candles.iter().map(|c| c.high).collect();
226        let lows: Vec<f64> = candles.iter().map(|c| c.low).collect();
227        let volumes: Vec<f64> = candles.iter().map(|c| c.volume as f64).collect();
228
229        for (name, indicator) in strategy.required_indicators() {
230            match indicator {
231                Indicator::Sma(period) => {
232                    let values = indicators::sma(&closes, period);
233                    result.insert(name, values);
234                }
235                Indicator::Ema(period) => {
236                    let values = indicators::ema(&closes, period);
237                    result.insert(name, values);
238                }
239                Indicator::Rsi(period) => {
240                    let values = indicators::rsi(&closes, period)?;
241                    result.insert(name, values);
242                }
243                Indicator::Macd { fast, slow, signal } => {
244                    let macd_result = indicators::macd(&closes, fast, slow, signal)?;
245                    result.insert("macd_line".to_string(), macd_result.macd_line);
246                    result.insert("macd_signal".to_string(), macd_result.signal_line);
247                    result.insert("macd_histogram".to_string(), macd_result.histogram);
248                }
249                Indicator::Bollinger { period, std_dev } => {
250                    let bb = indicators::bollinger_bands(&closes, period, std_dev)?;
251                    result.insert("bollinger_upper".to_string(), bb.upper);
252                    result.insert("bollinger_middle".to_string(), bb.middle);
253                    result.insert("bollinger_lower".to_string(), bb.lower);
254                }
255                Indicator::Atr(period) => {
256                    let values = indicators::atr(&highs, &lows, &closes, period)?;
257                    result.insert(name, values);
258                }
259                Indicator::Supertrend { period, multiplier } => {
260                    let st = indicators::supertrend(&highs, &lows, &closes, period, multiplier)?;
261                    result.insert("supertrend_value".to_string(), st.value);
262                    // Convert bool to f64 for consistency
263                    let uptrend: Vec<Option<f64>> = st
264                        .is_uptrend
265                        .into_iter()
266                        .map(|v| v.map(|b| if b { 1.0 } else { 0.0 }))
267                        .collect();
268                    result.insert("supertrend_uptrend".to_string(), uptrend);
269                }
270                Indicator::DonchianChannels(period) => {
271                    let dc = indicators::donchian_channels(&highs, &lows, period)?;
272                    result.insert("donchian_upper".to_string(), dc.upper);
273                    result.insert("donchian_middle".to_string(), dc.middle);
274                    result.insert("donchian_lower".to_string(), dc.lower);
275                }
276                Indicator::Wma(period) => {
277                    let values = indicators::wma(&closes, period)?;
278                    result.insert(name, values);
279                }
280                Indicator::Dema(period) => {
281                    let values = indicators::dema(&closes, period)?;
282                    result.insert(name, values);
283                }
284                Indicator::Tema(period) => {
285                    let values = indicators::tema(&closes, period)?;
286                    result.insert(name, values);
287                }
288                Indicator::Hma(period) => {
289                    let values = indicators::hma(&closes, period)?;
290                    result.insert(name, values);
291                }
292                Indicator::Obv => {
293                    let values = indicators::obv(&closes, &volumes)?;
294                    result.insert(name, values);
295                }
296                Indicator::Momentum(period) => {
297                    let values = indicators::momentum(&closes, period)?;
298                    result.insert(name, values);
299                }
300                Indicator::Roc(period) => {
301                    let values = indicators::roc(&closes, period)?;
302                    result.insert(name, values);
303                }
304                Indicator::Cci(period) => {
305                    let values = indicators::cci(&highs, &lows, &closes, period)?;
306                    result.insert(name, values);
307                }
308                Indicator::WilliamsR(period) => {
309                    let values = indicators::williams_r(&highs, &lows, &closes, period)?;
310                    result.insert(name, values);
311                }
312                Indicator::Adx(period) => {
313                    let values = indicators::adx(&highs, &lows, &closes, period)?;
314                    result.insert(name, values);
315                }
316                Indicator::Mfi(period) => {
317                    let values = indicators::mfi(&highs, &lows, &closes, &volumes, period)?;
318                    result.insert(name, values);
319                }
320                Indicator::Cmf(period) => {
321                    let values = indicators::cmf(&highs, &lows, &closes, &volumes, period)?;
322                    result.insert(name, values);
323                }
324                Indicator::Cmo(period) => {
325                    let values = indicators::cmo(&closes, period)?;
326                    result.insert(name, values);
327                }
328                Indicator::Vwma(period) => {
329                    let values = indicators::vwma(&closes, &volumes, period)?;
330                    result.insert(name, values);
331                }
332                Indicator::Alma {
333                    period,
334                    offset,
335                    sigma,
336                } => {
337                    let values = indicators::alma(&closes, period, offset, sigma)?;
338                    result.insert(name, values);
339                }
340                Indicator::McginleyDynamic(period) => {
341                    let values = indicators::mcginley_dynamic(&closes, period)?;
342                    result.insert(name, values);
343                }
344                // === OSCILLATORS ===
345                Indicator::Stochastic {
346                    k_period,
347                    k_slow: _,
348                    d_period,
349                } => {
350                    let stoch = indicators::stochastic(&highs, &lows, &closes, k_period, d_period)?;
351                    result.insert("stochastic_k".to_string(), stoch.k);
352                    result.insert("stochastic_d".to_string(), stoch.d);
353                }
354                Indicator::StochasticRsi {
355                    rsi_period,
356                    stoch_period,
357                    k_period: _,
358                    d_period: _,
359                } => {
360                    let values = indicators::stochastic_rsi(&closes, rsi_period, stoch_period)?;
361                    result.insert(name, values);
362                }
363                Indicator::AwesomeOscillator { fast: _, slow: _ } => {
364                    // Note: awesome_oscillator uses default periods (5, 34) internally
365                    let values = indicators::awesome_oscillator(&highs, &lows)?;
366                    result.insert(name, values);
367                }
368                Indicator::CoppockCurve {
369                    wma_period: _,
370                    long_roc: _,
371                    short_roc: _,
372                } => {
373                    // Note: coppock_curve uses default periods internally
374                    let values = indicators::coppock_curve(&closes)?;
375                    result.insert(name, values);
376                }
377                // === TREND INDICATORS ===
378                Indicator::Aroon(period) => {
379                    let aroon_result = indicators::aroon(&highs, &lows, period)?;
380                    result.insert("aroon_up".to_string(), aroon_result.aroon_up);
381                    result.insert("aroon_down".to_string(), aroon_result.aroon_down);
382                }
383                Indicator::Ichimoku {
384                    conversion: _,
385                    base: _,
386                    lagging: _,
387                    displacement: _,
388                } => {
389                    // Note: ichimoku uses default periods (9, 26, 52, 26) internally
390                    let ich = indicators::ichimoku(&highs, &lows, &closes)?;
391                    result.insert("ichimoku_conversion".to_string(), ich.conversion_line);
392                    result.insert("ichimoku_base".to_string(), ich.base_line);
393                    result.insert("ichimoku_leading_a".to_string(), ich.leading_span_a);
394                    result.insert("ichimoku_leading_b".to_string(), ich.leading_span_b);
395                    result.insert("ichimoku_lagging".to_string(), ich.lagging_span);
396                }
397                Indicator::ParabolicSar { step, max } => {
398                    let values = indicators::parabolic_sar(&highs, &lows, &closes, step, max)?;
399                    result.insert(name, values);
400                }
401                // === VOLATILITY ===
402                Indicator::KeltnerChannels {
403                    period,
404                    multiplier,
405                    atr_period,
406                } => {
407                    let kc = indicators::keltner_channels(
408                        &highs, &lows, &closes, period, atr_period, multiplier,
409                    )?;
410                    result.insert("keltner_upper".to_string(), kc.upper);
411                    result.insert("keltner_middle".to_string(), kc.middle);
412                    result.insert("keltner_lower".to_string(), kc.lower);
413                }
414                Indicator::TrueRange => {
415                    let values = indicators::true_range(&highs, &lows, &closes)?;
416                    result.insert(name, values);
417                }
418                Indicator::ChoppinessIndex(period) => {
419                    let values = indicators::choppiness_index(&highs, &lows, &closes, period)?;
420                    result.insert(name, values);
421                }
422                // === VOLUME INDICATORS ===
423                Indicator::Vwap => {
424                    let values = indicators::vwap(&highs, &lows, &closes, &volumes)?;
425                    result.insert(name, values);
426                }
427                Indicator::ChaikinOscillator => {
428                    let values = indicators::chaikin_oscillator(&highs, &lows, &closes, &volumes)?;
429                    result.insert(name, values);
430                }
431                Indicator::AccumulationDistribution => {
432                    let values =
433                        indicators::accumulation_distribution(&highs, &lows, &closes, &volumes)?;
434                    result.insert(name, values);
435                }
436                Indicator::BalanceOfPower(period) => {
437                    let opens: Vec<f64> = candles.iter().map(|c| c.open).collect();
438                    let values =
439                        indicators::balance_of_power(&opens, &highs, &lows, &closes, period)?;
440                    result.insert(name, values);
441                }
442                // === POWER/STRENGTH INDICATORS ===
443                Indicator::BullBearPower(_period) => {
444                    // Note: bull_bear_power uses default EMA period (13) internally
445                    let bbp = indicators::bull_bear_power(&highs, &lows, &closes)?;
446                    result.insert("bull_power".to_string(), bbp.bull_power);
447                    result.insert("bear_power".to_string(), bbp.bear_power);
448                }
449                Indicator::ElderRay(_period) => {
450                    // Note: elder_ray uses default EMA period (13) internally
451                    let er = indicators::elder_ray(&highs, &lows, &closes)?;
452                    result.insert("elder_bull".to_string(), er.bull_power);
453                    result.insert("elder_bear".to_string(), er.bear_power);
454                }
455            }
456        }
457
458        Ok(result)
459    }
460
461    /// Check if stop-loss or take-profit should trigger
462    fn check_sl_tp(&self, position: &Position, candle: &Candle) -> Option<Signal> {
463        let return_pct = position.unrealized_return_pct(candle.close) / 100.0;
464
465        // Check stop-loss
466        if let Some(sl_pct) = self.config.stop_loss_pct
467            && return_pct <= -sl_pct
468        {
469            return Some(
470                Signal::exit(candle.timestamp, candle.close)
471                    .with_reason(format!("Stop-loss triggered ({:.1}%)", return_pct * 100.0)),
472            );
473        }
474
475        // Check take-profit
476        if let Some(tp_pct) = self.config.take_profit_pct
477            && return_pct >= tp_pct
478        {
479            return Some(
480                Signal::exit(candle.timestamp, candle.close).with_reason(format!(
481                    "Take-profit triggered ({:.1}%)",
482                    return_pct * 100.0
483                )),
484            );
485        }
486
487        None
488    }
489
490    /// Execute a signal, modifying position and cash
491    fn execute_signal(
492        &self,
493        signal: &Signal,
494        candle: &Candle,
495        position: &mut Option<Position>,
496        cash: &mut f64,
497        trades: &mut Vec<Trade>,
498    ) -> bool {
499        match signal.direction {
500            SignalDirection::Long => {
501                if position.is_some() {
502                    return false; // Already have a position
503                }
504                self.open_position(position, cash, candle, signal, true)
505            }
506            SignalDirection::Short => {
507                if position.is_some() {
508                    return false; // Already have a position
509                }
510                if !self.config.allow_short {
511                    return false; // Shorts not allowed
512                }
513                self.open_position(position, cash, candle, signal, false)
514            }
515            SignalDirection::Exit => {
516                if position.is_none() {
517                    return false; // No position to exit
518                }
519                self.close_position(position, cash, trades, candle, signal)
520            }
521            SignalDirection::Hold => false,
522        }
523    }
524
525    /// Open a new position
526    fn open_position(
527        &self,
528        position: &mut Option<Position>,
529        cash: &mut f64,
530        candle: &Candle,
531        signal: &Signal,
532        is_long: bool,
533    ) -> bool {
534        let entry_price = self.config.apply_entry_slippage(candle.close, is_long);
535        let quantity = self.config.calculate_position_size(*cash, entry_price);
536
537        if quantity <= 0.0 {
538            return false; // Not enough capital
539        }
540
541        let entry_value = entry_price * quantity;
542        let commission = self.config.calculate_commission(entry_value);
543
544        if entry_value + commission > *cash {
545            return false; // Not enough capital including commission
546        }
547
548        let side = if is_long {
549            PositionSide::Long
550        } else {
551            PositionSide::Short
552        };
553
554        *cash -= entry_value + commission;
555        *position = Some(Position::new(
556            side,
557            candle.timestamp,
558            entry_price,
559            quantity,
560            commission,
561            signal.clone(),
562        ));
563
564        true
565    }
566
567    /// Close an existing position
568    fn close_position(
569        &self,
570        position: &mut Option<Position>,
571        cash: &mut f64,
572        trades: &mut Vec<Trade>,
573        candle: &Candle,
574        signal: &Signal,
575    ) -> bool {
576        let pos = match position.take() {
577            Some(p) => p,
578            None => return false,
579        };
580
581        let exit_price = self.config.apply_exit_slippage(candle.close, pos.is_long());
582        let exit_commission = self.config.calculate_commission(exit_price * pos.quantity);
583
584        let trade = pos.close(
585            candle.timestamp,
586            exit_price,
587            exit_commission,
588            signal.clone(),
589        );
590
591        *cash += trade.entry_value() + trade.pnl;
592        trades.push(trade);
593
594        true
595    }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::strategy::SmaCrossover;

    /// Build one candle per price, timestamped by index, with a ±1% high/low band.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p * 1.01,
                low: p * 0.99,
                close: p,
                volume: 1000,
                adj_close: Some(p),
            })
            .collect()
    }

    /// Engine invariant: exactly one equity point per input candle, in order.
    fn assert_equity_curve_matches(result: &BacktestResult, candles: &[Candle]) {
        assert_eq!(result.equity_curve.len(), candles.len());
        for (point, candle) in result.equity_curve.iter().zip(candles) {
            assert_eq!(point.timestamp, candle.timestamp);
        }
    }

    #[test]
    fn test_engine_basic() {
        // Price trends up then down - should trigger crossover signals
        let mut prices = vec![100.0; 30];
        // Make fast SMA cross above slow SMA around bar 15
        for (i, price) in prices.iter_mut().enumerate().take(25).skip(15) {
            *price = 100.0 + (i - 15) as f64 * 2.0;
        }
        // Then cross back down
        for (i, price) in prices.iter_mut().enumerate().take(30).skip(25) {
            *price = 118.0 - (i - 25) as f64 * 3.0;
        }

        let candles = make_candles(&prices);
        let config = BacktestConfig::builder()
            .initial_capital(10_000.0)
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let engine = BacktestEngine::new(config);
        let strategy = SmaCrossover::new(5, 10);
        let result = engine.run("TEST", &candles, strategy).unwrap();

        assert_eq!(result.symbol, "TEST");
        assert_eq!(result.strategy_name, "SMA Crossover");
        assert_equity_curve_matches(&result, &candles);
    }

    #[test]
    fn test_stop_loss() {
        // Price drops significantly after entry
        let mut prices = vec![100.0; 20];
        // Trend up to trigger long entry
        for (i, price) in prices.iter_mut().enumerate().take(15).skip(10) {
            *price = 100.0 + (i - 10) as f64 * 2.0;
        }
        // Then crash
        for (i, price) in prices.iter_mut().enumerate().take(20).skip(15) {
            *price = 108.0 - (i - 15) as f64 * 10.0;
        }

        let candles = make_candles(&prices);
        let config = BacktestConfig::builder()
            .initial_capital(10_000.0)
            .stop_loss_pct(0.05) // 5% stop loss
            .commission_pct(0.0)
            .slippage_pct(0.0)
            .build()
            .unwrap();

        let engine = BacktestEngine::new(config);
        let strategy = SmaCrossover::new(3, 6);
        let result = engine.run("TEST", &candles, strategy).unwrap();

        let sl_signals: Vec<_> = result
            .signals
            .iter()
            .filter(|s| s.reason.as_ref().is_some_and(|r| r.contains("Stop-loss")))
            .collect();

        // Whether the stop fires depends on the strategy's entry timing, but
        // any stop-loss record must be an executed exit that closed a trade.
        for record in &sl_signals {
            assert!(record.executed, "stop-loss exits are always executed");
            assert!(matches!(record.direction, SignalDirection::Exit));
        }
        assert!(sl_signals.len() <= result.trades.len());
        assert_equity_curve_matches(&result, &candles);
    }

    #[test]
    fn test_insufficient_data() {
        let candles = make_candles(&[100.0, 101.0, 102.0]); // Only 3 candles
        let config = BacktestConfig::default();
        let engine = BacktestEngine::new(config);
        let strategy = SmaCrossover::new(10, 20); // Needs at least 21 candles

        let result = engine.run("TEST", &candles, strategy);
        assert!(result.is_err());
    }
}