Skip to main content

finance_query/backtesting/strategy/
prebuilt.rs

//! Pre-built trading strategies.
//!
//! Ready-to-use strategy implementations that can be used directly with the backtest engine.
//! Each strategy implements the [`Strategy`] trait and can be customized via builder methods.
//!
//! # Available Strategies
//!
//! | Strategy | Description |
//! |----------|-------------|
//! | [`SmaCrossover`] | Dual SMA crossover (trend following) |
//! | [`RsiReversal`] | RSI mean reversion |
//! | [`MacdSignal`] | MACD line / signal line crossover |
//! | [`BollingerMeanReversion`] | Bollinger Bands mean reversion |
//! | [`SuperTrendFollow`] | SuperTrend trend following |
//! | [`DonchianBreakout`] | Donchian channel breakout |
//!
//! # Example
//!
//! ```ignore
//! use finance_query::backtesting::{SmaCrossover, BacktestConfig};
//!
//! // Use a pre-built strategy directly
//! let strategy = SmaCrossover::new(10, 20).with_short(true);
//! ```
25
26use crate::indicators::Indicator;
27
28use super::{Signal, Strategy, StrategyContext};
29use crate::backtesting::signal::SignalStrength;
30
/// Dual simple-moving-average (SMA) crossover strategy.
///
/// Opens a long when the fast SMA crosses above the slow SMA and closes it when
/// the fast SMA crosses back below. When shorting is enabled, the bearish
/// crossover also opens a short, which the next bullish crossover closes.
#[derive(Debug, Clone)]
pub struct SmaCrossover {
    /// Lookback of the fast SMA.
    pub fast_period: usize,
    /// Lookback of the slow SMA.
    pub slow_period: usize,
    /// Whether bearish crossovers may open short positions.
    pub allow_short: bool,
}

impl SmaCrossover {
    /// Builds a long-only crossover strategy from the two SMA lookbacks.
    ///
    /// No check is made that `fast_period` is actually shorter than `slow_period`.
    pub fn new(fast_period: usize, slow_period: usize) -> Self {
        let allow_short = false;
        Self {
            fast_period,
            slow_period,
            allow_short,
        }
    }

    /// Enables (or disables) opening shorts on bearish crossovers.
    pub fn with_short(self, allow: bool) -> Self {
        Self {
            allow_short: allow,
            ..self
        }
    }

    /// Indicator key under which the fast SMA is registered and looked up.
    fn fast_key(&self) -> String {
        Self::sma_key(self.fast_period)
    }

    /// Indicator key under which the slow SMA is registered and looked up.
    fn slow_key(&self) -> String {
        Self::sma_key(self.slow_period)
    }

    /// Shared `sma_<period>` naming scheme used for both SMA keys.
    fn sma_key(period: usize) -> String {
        format!("sma_{period}")
    }
}

impl Default for SmaCrossover {
    /// 10/20 crossover, long-only.
    fn default() -> Self {
        Self {
            fast_period: 10,
            slow_period: 20,
            allow_short: false,
        }
    }
}
76
77impl Strategy for SmaCrossover {
78    fn name(&self) -> &str {
79        "SMA Crossover"
80    }
81
82    fn required_indicators(&self) -> Vec<(String, Indicator)> {
83        vec![
84            (self.fast_key(), Indicator::Sma(self.fast_period)),
85            (self.slow_key(), Indicator::Sma(self.slow_period)),
86        ]
87    }
88
89    fn warmup_period(&self) -> usize {
90        self.slow_period.max(self.fast_period) + 1
91    }
92
93    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
94        let candle = ctx.current_candle();
95
96        // Bullish crossover: fast crosses above slow
97        if ctx.crossed_above(&self.fast_key(), &self.slow_key()) {
98            if ctx.is_short() {
99                return Signal::exit(candle.timestamp, candle.close)
100                    .with_reason("SMA bullish crossover - close short");
101            }
102            if !ctx.has_position() {
103                return Signal::long(candle.timestamp, candle.close)
104                    .with_reason("SMA bullish crossover");
105            }
106        }
107
108        // Bearish crossover: fast crosses below slow
109        if ctx.crossed_below(&self.fast_key(), &self.slow_key()) {
110            if ctx.is_long() {
111                return Signal::exit(candle.timestamp, candle.close)
112                    .with_reason("SMA bearish crossover - close long");
113            }
114            if !ctx.has_position() && self.allow_short {
115                return Signal::short(candle.timestamp, candle.close)
116                    .with_reason("SMA bearish crossover");
117            }
118        }
119
120        Signal::hold()
121    }
122}
123
/// RSI reversal (mean-reversion) strategy.
///
/// Opens a long when RSI crosses back above the oversold threshold and closes
/// that long when RSI crosses back below the overbought threshold. When
/// shorting is enabled, the cross below the overbought threshold also opens a
/// short, which is closed when RSI next crosses above the oversold threshold.
#[derive(Debug, Clone)]
pub struct RsiReversal {
    /// RSI lookback period.
    pub period: usize,
    /// Oversold threshold (default 30).
    pub oversold: f64,
    /// Overbought threshold (default 70).
    pub overbought: f64,
    /// Allow short positions.
    pub allow_short: bool,
}

impl RsiReversal {
    /// Create a long-only RSI reversal strategy with the default 30/70 thresholds.
    pub fn new(period: usize) -> Self {
        Self {
            period,
            oversold: 30.0,
            overbought: 70.0,
            allow_short: false,
        }
    }

    /// Set custom oversold/overbought thresholds.
    ///
    /// The values are not validated; `oversold` is expected to be below
    /// `overbought` for the entry/exit rules to make sense.
    pub fn with_thresholds(mut self, oversold: f64, overbought: f64) -> Self {
        self.oversold = oversold;
        self.overbought = overbought;
        self
    }

    /// Enable (or disable) short positions.
    pub fn with_short(mut self, allow: bool) -> Self {
        self.allow_short = allow;
        self
    }

    /// Indicator key (`rsi_<period>`) under which the RSI series is registered.
    fn rsi_key(&self) -> String {
        format!("rsi_{}", self.period)
    }
}

impl Default for RsiReversal {
    /// 14-period RSI with 30/70 thresholds, long-only.
    fn default() -> Self {
        Self::new(14)
    }
}
175
176impl Strategy for RsiReversal {
177    fn name(&self) -> &str {
178        "RSI Reversal"
179    }
180
181    fn required_indicators(&self) -> Vec<(String, Indicator)> {
182        vec![(self.rsi_key(), Indicator::Rsi(self.period))]
183    }
184
185    fn warmup_period(&self) -> usize {
186        self.period + 1
187    }
188
189    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
190        let candle = ctx.current_candle();
191        let rsi = ctx.indicator(&self.rsi_key());
192
193        let Some(rsi_val) = rsi else {
194            return Signal::hold();
195        };
196
197        // Calculate signal strength based on RSI extremity
198        let strength = if !(20.0..=80.0).contains(&rsi_val) {
199            SignalStrength::strong()
200        } else if !(25.0..=75.0).contains(&rsi_val) {
201            SignalStrength::medium()
202        } else {
203            SignalStrength::weak()
204        };
205
206        // Bullish: RSI crosses above oversold
207        if ctx.indicator_crossed_above(&self.rsi_key(), self.oversold) {
208            if ctx.is_short() {
209                return Signal::exit(candle.timestamp, candle.close)
210                    .with_strength(strength)
211                    .with_reason(format!(
212                        "RSI crossed above {:.0} - close short",
213                        self.oversold
214                    ));
215            }
216            if !ctx.has_position() {
217                return Signal::long(candle.timestamp, candle.close)
218                    .with_strength(strength)
219                    .with_reason(format!("RSI crossed above {:.0}", self.oversold));
220            }
221        }
222
223        // Bearish: RSI crosses below overbought
224        if ctx.indicator_crossed_below(&self.rsi_key(), self.overbought) {
225            if ctx.is_long() {
226                return Signal::exit(candle.timestamp, candle.close)
227                    .with_strength(strength)
228                    .with_reason(format!(
229                        "RSI crossed below {:.0} - close long",
230                        self.overbought
231                    ));
232            }
233            if !ctx.has_position() && self.allow_short {
234                return Signal::short(candle.timestamp, candle.close)
235                    .with_strength(strength)
236                    .with_reason(format!("RSI crossed below {:.0}", self.overbought));
237            }
238        }
239
240        Signal::hold()
241    }
242}
243
/// MACD signal-line crossover strategy.
///
/// Opens a long when the MACD line crosses above its signal line and closes it
/// when the MACD line crosses back below. When shorting is enabled, the
/// bearish crossover also opens a short, which the next bullish crossover closes.
#[derive(Debug, Clone)]
pub struct MacdSignal {
    /// Period of the fast EMA.
    pub fast: usize,
    /// Period of the slow EMA.
    pub slow: usize,
    /// Period of the signal line.
    pub signal: usize,
    /// Whether bearish crossovers may open short positions.
    pub allow_short: bool,
}

impl MacdSignal {
    /// Builds a long-only MACD strategy from the three MACD periods.
    pub fn new(fast: usize, slow: usize, signal: usize) -> Self {
        let allow_short = false;
        Self {
            fast,
            slow,
            signal,
            allow_short,
        }
    }

    /// Turns short selling on or off.
    pub fn with_short(self, allow: bool) -> Self {
        Self {
            allow_short: allow,
            ..self
        }
    }
}

impl Default for MacdSignal {
    /// The classic 12/26/9 configuration, long-only.
    fn default() -> Self {
        Self {
            fast: 12,
            slow: 26,
            signal: 9,
            allow_short: false,
        }
    }
}
284
285impl Strategy for MacdSignal {
286    fn name(&self) -> &str {
287        "MACD Signal"
288    }
289
290    fn required_indicators(&self) -> Vec<(String, Indicator)> {
291        vec![(
292            "macd".to_string(),
293            Indicator::Macd {
294                fast: self.fast,
295                slow: self.slow,
296                signal: self.signal,
297            },
298        )]
299    }
300
301    fn warmup_period(&self) -> usize {
302        self.slow + self.signal
303    }
304
305    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
306        let candle = ctx.current_candle();
307
308        // MACD line and signal line are stored separately by the engine
309        // Bullish crossover
310        if ctx.crossed_above("macd_line", "macd_signal") {
311            if ctx.is_short() {
312                return Signal::exit(candle.timestamp, candle.close)
313                    .with_reason("MACD bullish crossover - close short");
314            }
315            if !ctx.has_position() {
316                return Signal::long(candle.timestamp, candle.close)
317                    .with_reason("MACD bullish crossover");
318            }
319        }
320
321        // Bearish crossover
322        if ctx.crossed_below("macd_line", "macd_signal") {
323            if ctx.is_long() {
324                return Signal::exit(candle.timestamp, candle.close)
325                    .with_reason("MACD bearish crossover - close long");
326            }
327            if !ctx.has_position() && self.allow_short {
328                return Signal::short(candle.timestamp, candle.close)
329                    .with_reason("MACD bearish crossover");
330            }
331        }
332
333        Signal::hold()
334    }
335}
336
/// Bollinger Bands mean-reversion strategy.
///
/// Opens a long when the close is at or below the lower band, and closes it
/// once the close reaches the middle band (default) or the upper band. When
/// shorting is enabled, a close at or above the upper band opens a short, which
/// is closed at the middle band (default) or the lower band.
#[derive(Debug, Clone)]
pub struct BollingerMeanReversion {
    /// Lookback of the middle-band SMA.
    pub period: usize,
    /// Band width as a multiple of the standard deviation.
    pub std_dev: f64,
    /// Whether touches of the upper band may open short positions.
    pub allow_short: bool,
    /// Take profit at the middle band (`true`) or at the opposite band (`false`).
    pub exit_at_middle: bool,
}

impl BollingerMeanReversion {
    /// Builds a long-only strategy that exits at the middle band.
    pub fn new(period: usize, std_dev: f64) -> Self {
        Self {
            period,
            std_dev,
            allow_short: false,
            exit_at_middle: true,
        }
    }

    /// Turns short selling on or off.
    pub fn with_short(self, allow: bool) -> Self {
        Self {
            allow_short: allow,
            ..self
        }
    }

    /// Chooses the profit target: middle band (`true`) or opposite band (`false`).
    pub fn exit_at_middle(self, at_middle: bool) -> Self {
        Self {
            exit_at_middle: at_middle,
            ..self
        }
    }
}

impl Default for BollingerMeanReversion {
    /// 20-period bands at 2 standard deviations, long-only, middle-band exit.
    fn default() -> Self {
        Self {
            period: 20,
            std_dev: 2.0,
            allow_short: false,
            exit_at_middle: true,
        }
    }
}
383
384impl Strategy for BollingerMeanReversion {
385    fn name(&self) -> &str {
386        "Bollinger Mean Reversion"
387    }
388
389    fn required_indicators(&self) -> Vec<(String, Indicator)> {
390        vec![(
391            "bollinger".to_string(),
392            Indicator::Bollinger {
393                period: self.period,
394                std_dev: self.std_dev,
395            },
396        )]
397    }
398
399    fn warmup_period(&self) -> usize {
400        self.period
401    }
402
403    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
404        let candle = ctx.current_candle();
405        let close = candle.close;
406
407        let lower = ctx.indicator("bollinger_lower");
408        let middle = ctx.indicator("bollinger_middle");
409        let upper = ctx.indicator("bollinger_upper");
410
411        let (Some(lower_val), Some(middle_val), Some(upper_val)) = (lower, middle, upper) else {
412            return Signal::hold();
413        };
414
415        // Long entry: price at or below lower band
416        if close <= lower_val && !ctx.has_position() {
417            return Signal::long(candle.timestamp, close)
418                .with_reason("Price at lower Bollinger Band");
419        }
420
421        // Long exit
422        if ctx.is_long() {
423            let exit_level = if self.exit_at_middle {
424                middle_val
425            } else {
426                upper_val
427            };
428            if close >= exit_level {
429                return Signal::exit(candle.timestamp, close).with_reason(format!(
430                    "Price reached {} Bollinger Band",
431                    if self.exit_at_middle {
432                        "middle"
433                    } else {
434                        "upper"
435                    }
436                ));
437            }
438        }
439
440        // Short entry: price at or above upper band
441        if close >= upper_val && !ctx.has_position() && self.allow_short {
442            return Signal::short(candle.timestamp, close)
443                .with_reason("Price at upper Bollinger Band");
444        }
445
446        // Short exit
447        if ctx.is_short() {
448            let exit_level = if self.exit_at_middle {
449                middle_val
450            } else {
451                lower_val
452            };
453            if close <= exit_level {
454                return Signal::exit(candle.timestamp, close).with_reason(format!(
455                    "Price reached {} Bollinger Band",
456                    if self.exit_at_middle {
457                        "middle"
458                    } else {
459                        "lower"
460                    }
461                ));
462            }
463        }
464
465        Signal::hold()
466    }
467}
468
/// SuperTrend trend-following strategy.
///
/// Opens a long when the SuperTrend direction flips to up and closes it when
/// the direction flips to down. When shorting is enabled, the downward flip
/// also opens a short, which is closed on the next upward flip. Without
/// shorting, a downward flip only closes an existing long.
#[derive(Debug, Clone)]
pub struct SuperTrendFollow {
    /// ATR period
    pub period: usize,
    /// ATR multiplier
    pub multiplier: f64,
    /// Allow short positions on downward flips.
    pub allow_short: bool,
}

impl SuperTrendFollow {
    /// Create a new long-only SuperTrend following strategy.
    pub fn new(period: usize, multiplier: f64) -> Self {
        Self {
            period,
            multiplier,
            allow_short: false,
        }
    }

    /// Enable (or disable) short positions.
    pub fn with_short(mut self, allow: bool) -> Self {
        self.allow_short = allow;
        self
    }
}

impl Default for SuperTrendFollow {
    /// Period 10, multiplier 3.0, long-only.
    fn default() -> Self {
        Self::new(10, 3.0)
    }
}
505
506impl Strategy for SuperTrendFollow {
507    fn name(&self) -> &str {
508        "SuperTrend Follow"
509    }
510
511    fn required_indicators(&self) -> Vec<(String, Indicator)> {
512        vec![(
513            "supertrend".to_string(),
514            Indicator::Supertrend {
515                period: self.period,
516                multiplier: self.multiplier,
517            },
518        )]
519    }
520
521    fn warmup_period(&self) -> usize {
522        self.period + 1
523    }
524
525    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
526        let candle = ctx.current_candle();
527
528        // SuperTrend uptrend stored as 1.0, downtrend as 0.0
529        let trend_now = ctx.indicator("supertrend_uptrend");
530        let trend_prev = ctx.indicator_prev("supertrend_uptrend");
531
532        let (Some(now), Some(prev)) = (trend_now, trend_prev) else {
533            return Signal::hold();
534        };
535
536        let is_uptrend = now > 0.5;
537        let was_uptrend = prev > 0.5;
538
539        // Trend changed to bullish
540        if is_uptrend && !was_uptrend {
541            if ctx.is_short() {
542                return Signal::exit(candle.timestamp, candle.close)
543                    .with_reason("SuperTrend turned bullish - close short");
544            }
545            if !ctx.has_position() {
546                return Signal::long(candle.timestamp, candle.close)
547                    .with_reason("SuperTrend turned bullish");
548            }
549        }
550
551        // Trend changed to bearish
552        if !is_uptrend && was_uptrend {
553            if ctx.is_long() {
554                return Signal::exit(candle.timestamp, candle.close)
555                    .with_reason("SuperTrend turned bearish - close long");
556            }
557            if !ctx.has_position() && self.allow_short {
558                return Signal::short(candle.timestamp, candle.close)
559                    .with_reason("SuperTrend turned bearish");
560            }
561        }
562
563        Signal::hold()
564    }
565}
566
/// Donchian channel breakout strategy.
///
/// Opens a long when the close breaks above the previous bar's upper channel.
/// A long is closed when the close breaks below the previous bar's lower
/// channel, or, with `exit_at_middle` enabled (the default), as soon as the
/// close falls back to the middle channel.
///
/// When shorting is enabled, a breakdown below the previous lower channel opens
/// a short. With `exit_at_middle` enabled, that short is closed once the close
/// rises back to the middle channel.
#[derive(Debug, Clone)]
pub struct DonchianBreakout {
    /// Channel period
    pub period: usize,
    /// Allow short positions on downward breakouts.
    pub allow_short: bool,
    /// When `true` (the default), positions are also closed once the close
    /// returns to the middle channel.
    pub exit_at_middle: bool,
}

impl DonchianBreakout {
    /// Create a new long-only Donchian breakout strategy with middle-channel exits.
    pub fn new(period: usize) -> Self {
        Self {
            period,
            allow_short: false,
            exit_at_middle: true,
        }
    }

    /// Enable (or disable) short positions.
    pub fn with_short(mut self, allow: bool) -> Self {
        self.allow_short = allow;
        self
    }

    /// Enable (or disable) the middle-channel exit.
    pub fn exit_at_middle(mut self, at_middle: bool) -> Self {
        self.exit_at_middle = at_middle;
        self
    }
}

impl Default for DonchianBreakout {
    /// 20-period channel, long-only, middle-channel exit.
    fn default() -> Self {
        Self::new(20)
    }
}
610
611impl Strategy for DonchianBreakout {
612    fn name(&self) -> &str {
613        "Donchian Breakout"
614    }
615
616    fn required_indicators(&self) -> Vec<(String, Indicator)> {
617        vec![(
618            "donchian".to_string(),
619            Indicator::DonchianChannels(self.period),
620        )]
621    }
622
623    fn warmup_period(&self) -> usize {
624        self.period
625    }
626
627    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
628        let candle = ctx.current_candle();
629        let close = candle.close;
630
631        let upper = ctx.indicator("donchian_upper");
632        let middle = ctx.indicator("donchian_middle");
633        let lower = ctx.indicator("donchian_lower");
634        let prev_upper = ctx.indicator_prev("donchian_upper");
635        let prev_lower = ctx.indicator_prev("donchian_lower");
636
637        let (Some(_upper_val), Some(middle_val), Some(_lower_val)) = (upper, middle, lower) else {
638            return Signal::hold();
639        };
640
641        // Breakout above previous upper channel -> go long
642        if let Some(prev_up) = prev_upper
643            && close > prev_up
644            && !ctx.has_position()
645        {
646            return Signal::long(candle.timestamp, close)
647                .with_reason("Donchian upper channel breakout");
648        }
649
650        // Breakout below previous lower channel
651        if let Some(prev_low) = prev_lower
652            && close < prev_low
653        {
654            if ctx.is_long() {
655                return Signal::exit(candle.timestamp, close)
656                    .with_reason("Donchian lower channel breakdown - close long");
657            }
658            if !ctx.has_position() && self.allow_short {
659                return Signal::short(candle.timestamp, close)
660                    .with_reason("Donchian lower channel breakdown");
661            }
662        }
663
664        // Exit long at middle
665        if ctx.is_long() && self.exit_at_middle && close <= middle_val {
666            return Signal::exit(candle.timestamp, close)
667                .with_reason("Price reached Donchian middle channel");
668        }
669
670        // Exit short at middle
671        if ctx.is_short() && self.exit_at_middle && close >= middle_val {
672            return Signal::exit(candle.timestamp, close)
673                .with_reason("Price reached Donchian middle channel");
674        }
675
676        Signal::hold()
677    }
678}
679
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sma_crossover_default() {
        let s = SmaCrossover::default();
        assert_eq!(s.fast_period, 10);
        assert_eq!(s.slow_period, 20);
        assert!(!s.allow_short);
    }

    #[test]
    fn test_sma_crossover_with_short() {
        let s = SmaCrossover::new(5, 15).with_short(true);
        assert_eq!(s.fast_period, 5);
        assert_eq!(s.slow_period, 15);
        assert!(s.allow_short);
    }

    #[test]
    fn test_rsi_default() {
        let s = RsiReversal::default();
        assert_eq!(s.period, 14);
        assert!((s.oversold - 30.0).abs() < 0.01);
        assert!((s.overbought - 70.0).abs() < 0.01);
    }

    #[test]
    fn test_rsi_with_thresholds() {
        let s = RsiReversal::new(10).with_thresholds(25.0, 75.0);
        assert_eq!(s.period, 10);
        assert!((s.oversold - 25.0).abs() < 0.01);
        assert!((s.overbought - 75.0).abs() < 0.01);
    }

    #[test]
    fn test_macd_default() {
        let s = MacdSignal::default();
        assert_eq!(s.fast, 12);
        assert_eq!(s.slow, 26);
        assert_eq!(s.signal, 9);
    }

    #[test]
    fn test_bollinger_default() {
        let s = BollingerMeanReversion::default();
        assert_eq!(s.period, 20);
        assert!((s.std_dev - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_supertrend_default() {
        let s = SuperTrendFollow::default();
        assert_eq!(s.period, 10);
        assert!((s.multiplier - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_donchian_default() {
        let s = DonchianBreakout::default();
        assert_eq!(s.period, 20);
        assert!(s.exit_at_middle);
    }

    #[test]
    fn test_strategy_names() {
        assert_eq!(SmaCrossover::default().name(), "SMA Crossover");
        assert_eq!(RsiReversal::default().name(), "RSI Reversal");
        assert_eq!(MacdSignal::default().name(), "MACD Signal");
        assert_eq!(
            BollingerMeanReversion::default().name(),
            "Bollinger Mean Reversion"
        );
        assert_eq!(SuperTrendFollow::default().name(), "SuperTrend Follow");
        assert_eq!(DonchianBreakout::default().name(), "Donchian Breakout");
    }

    #[test]
    fn test_required_indicators() {
        let sma = SmaCrossover::new(5, 10);
        let indicators = sma.required_indicators();
        assert_eq!(indicators.len(), 2);
        assert_eq!(indicators[0].0, "sma_5");
        assert_eq!(indicators[1].0, "sma_10");

        let rsi = RsiReversal::new(14);
        let indicators = rsi.required_indicators();
        assert_eq!(indicators.len(), 1);
        assert_eq!(indicators[0].0, "rsi_14");
    }

    #[test]
    fn test_single_indicator_keys() {
        // Composite indicators are each registered under one base key.
        let cases = [
            (MacdSignal::default().required_indicators(), "macd"),
            (
                BollingerMeanReversion::default().required_indicators(),
                "bollinger",
            ),
            (SuperTrendFollow::default().required_indicators(), "supertrend"),
            (DonchianBreakout::default().required_indicators(), "donchian"),
        ];
        for (indicators, key) in cases {
            assert_eq!(indicators.len(), 1);
            assert_eq!(indicators[0].0, key);
        }
    }

    #[test]
    fn test_warmup_periods() {
        assert_eq!(SmaCrossover::new(5, 10).warmup_period(), 11);
        // The longer of the two lookbacks drives warmup, whichever argument it is.
        assert_eq!(SmaCrossover::new(30, 10).warmup_period(), 31);
        assert_eq!(RsiReversal::new(14).warmup_period(), 15);
        assert_eq!(MacdSignal::default().warmup_period(), 35);
        assert_eq!(BollingerMeanReversion::default().warmup_period(), 20);
        assert_eq!(SuperTrendFollow::default().warmup_period(), 11);
        assert_eq!(DonchianBreakout::default().warmup_period(), 20);
    }

    #[test]
    fn test_builders() {
        assert!(!RsiReversal::default().allow_short);
        assert!(RsiReversal::new(14).with_short(true).allow_short);

        assert!(!MacdSignal::default().allow_short);
        assert!(MacdSignal::default().with_short(true).allow_short);

        let bb = BollingerMeanReversion::new(10, 1.5)
            .with_short(true)
            .exit_at_middle(false);
        assert_eq!(bb.period, 10);
        assert!((bb.std_dev - 1.5).abs() < 0.01);
        assert!(bb.allow_short);
        assert!(!bb.exit_at_middle);

        assert!(!SuperTrendFollow::default().allow_short);
        assert!(SuperTrendFollow::new(7, 2.0).with_short(true).allow_short);

        let dc = DonchianBreakout::new(55).with_short(true).exit_at_middle(false);
        assert_eq!(dc.period, 55);
        assert!(dc.allow_short);
        assert!(!dc.exit_at_middle);
    }
}