Skip to main content

finance_query/backtesting/strategy/
prebuilt.rs

//! Pre-built trading strategies.
//!
//! Ready-to-use strategy implementations that can be used directly with the backtest engine.
//! Each strategy implements the [`Strategy`] trait and can be customized via its constructor
//! arguments, its public fields, and, where available, builder methods.
//!
//! Short signals are always emitted when the condition is met. Whether they are
//! *executed* is controlled solely by [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig).
//!
//! # Available Strategies
//!
//! | Strategy | Description |
//! |----------|-------------|
//! | [`SmaCrossover`] | Dual SMA crossover (trend following) |
//! | [`RsiReversal`] | RSI mean reversion |
//! | [`MacdSignal`] | MACD line crossover |
//! | [`BollingerMeanReversion`] | Bollinger Bands mean reversion |
//! | [`SuperTrendFollow`] | SuperTrend trend following |
//! | [`DonchianBreakout`] | Donchian channel breakout |
//!
//! # Example
//!
//! ```ignore
//! use finance_query::backtesting::{SmaCrossover, BacktestConfig};
//!
//! let strategy = SmaCrossover::new(10, 20);
//! let config = BacktestConfig::builder().allow_short(true).build().unwrap();
//! ```
28
29use crate::indicators::Indicator;
30
31use super::{Signal, Strategy, StrategyContext};
32use crate::backtesting::signal::SignalStrength;
33
/// SMA Crossover Strategy
///
/// Opens a long position when the fast SMA crosses above the slow SMA and
/// closes it when the fast SMA crosses back below. Bearish crossovers also
/// emit short signals (execution gated by
/// [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig)).
#[derive(Debug, Clone)]
pub struct SmaCrossover {
    /// Fast SMA period
    pub fast_period: usize,
    /// Slow SMA period
    pub slow_period: usize,
}

impl SmaCrossover {
    /// Build an SMA crossover strategy from its fast and slow periods.
    pub fn new(fast_period: usize, slow_period: usize) -> Self {
        Self {
            fast_period,
            slow_period,
        }
    }

    /// Indicator key used for an SMA of the given period (e.g. `"sma_10"`).
    fn sma_key(period: usize) -> String {
        format!("sma_{period}")
    }

    /// Indicator key of the fast SMA.
    fn fast_key(&self) -> String {
        Self::sma_key(self.fast_period)
    }

    /// Indicator key of the slow SMA.
    fn slow_key(&self) -> String {
        Self::sma_key(self.slow_period)
    }
}

impl Default for SmaCrossover {
    /// A 10/20 crossover.
    fn default() -> Self {
        Self {
            fast_period: 10,
            slow_period: 20,
        }
    }
}
71
72impl Strategy for SmaCrossover {
73    fn name(&self) -> &str {
74        "SMA Crossover"
75    }
76
77    fn required_indicators(&self) -> Vec<(String, Indicator)> {
78        vec![
79            (self.fast_key(), Indicator::Sma(self.fast_period)),
80            (self.slow_key(), Indicator::Sma(self.slow_period)),
81        ]
82    }
83
84    fn warmup_period(&self) -> usize {
85        self.slow_period.max(self.fast_period) + 1
86    }
87
88    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
89        let candle = ctx.current_candle();
90
91        // Bullish crossover: fast crosses above slow
92        if ctx.crossed_above(&self.fast_key(), &self.slow_key()) {
93            if ctx.is_short() {
94                return Signal::exit(candle.timestamp, candle.close)
95                    .with_reason("SMA bullish crossover - close short");
96            }
97            if !ctx.has_position() {
98                return Signal::long(candle.timestamp, candle.close)
99                    .with_reason("SMA bullish crossover");
100            }
101        }
102
103        // Bearish crossover: fast crosses below slow
104        if ctx.crossed_below(&self.fast_key(), &self.slow_key()) {
105            if ctx.is_long() {
106                return Signal::exit(candle.timestamp, candle.close)
107                    .with_reason("SMA bearish crossover - close long");
108            }
109            if !ctx.has_position() {
110                return Signal::short(candle.timestamp, candle.close)
111                    .with_reason("SMA bearish crossover");
112            }
113        }
114
115        Signal::hold()
116    }
117}
118
/// RSI Reversal Strategy
///
/// Goes long when RSI crosses above the oversold level, and closes that long
/// when RSI crosses back *below* the overbought level. It does not exit merely
/// because RSI has reached the overbought level. An open short is closed on the
/// oversold cross-up.
///
/// Emits short signals on the same overbought cross-down (execution gated by
/// [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig)).
///
/// Signals carry a strength graded by how extreme the RSI reading is on the
/// signal bar. The grading bands are fixed and do not follow thresholds set via
/// [`RsiReversal::with_thresholds`].
#[derive(Debug, Clone)]
pub struct RsiReversal {
    /// RSI period
    pub period: usize,
    /// Oversold threshold (default 30)
    pub oversold: f64,
    /// Overbought threshold (default 70)
    pub overbought: f64,
}

impl RsiReversal {
    /// Create a new RSI reversal strategy with the default 30/70 thresholds.
    pub fn new(period: usize) -> Self {
        Self {
            period,
            oversold: 30.0,
            overbought: 70.0,
        }
    }

    /// Set custom oversold/overbought thresholds.
    ///
    /// The values are stored as given; `oversold < overbought` is not checked.
    pub fn with_thresholds(mut self, oversold: f64, overbought: f64) -> Self {
        self.oversold = oversold;
        self.overbought = overbought;
        self
    }

    /// Indicator key for this strategy's RSI series (e.g. `"rsi_14"`).
    fn rsi_key(&self) -> String {
        format!("rsi_{}", self.period)
    }
}

impl Default for RsiReversal {
    /// 14-period RSI with 30/70 thresholds.
    fn default() -> Self {
        Self::new(14)
    }
}
162
163impl Strategy for RsiReversal {
164    fn name(&self) -> &str {
165        "RSI Reversal"
166    }
167
168    fn required_indicators(&self) -> Vec<(String, Indicator)> {
169        vec![(self.rsi_key(), Indicator::Rsi(self.period))]
170    }
171
172    fn warmup_period(&self) -> usize {
173        self.period + 1
174    }
175
176    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
177        let candle = ctx.current_candle();
178        let rsi = ctx.indicator(&self.rsi_key());
179
180        let Some(rsi_val) = rsi else {
181            return Signal::hold();
182        };
183
184        // Calculate signal strength based on RSI extremity
185        let strength = if !(20.0..=80.0).contains(&rsi_val) {
186            SignalStrength::strong()
187        } else if !(25.0..=75.0).contains(&rsi_val) {
188            SignalStrength::medium()
189        } else {
190            SignalStrength::weak()
191        };
192
193        // Bullish: RSI crosses above oversold
194        if ctx.indicator_crossed_above(&self.rsi_key(), self.oversold) {
195            if ctx.is_short() {
196                return Signal::exit(candle.timestamp, candle.close)
197                    .with_strength(strength)
198                    .with_reason(format!(
199                        "RSI crossed above {:.0} - close short",
200                        self.oversold
201                    ));
202            }
203            if !ctx.has_position() {
204                return Signal::long(candle.timestamp, candle.close)
205                    .with_strength(strength)
206                    .with_reason(format!("RSI crossed above {:.0}", self.oversold));
207            }
208        }
209
210        // Bearish: RSI crosses below overbought
211        if ctx.indicator_crossed_below(&self.rsi_key(), self.overbought) {
212            if ctx.is_long() {
213                return Signal::exit(candle.timestamp, candle.close)
214                    .with_strength(strength)
215                    .with_reason(format!(
216                        "RSI crossed below {:.0} - close long",
217                        self.overbought
218                    ));
219            }
220            if !ctx.has_position() {
221                return Signal::short(candle.timestamp, candle.close)
222                    .with_strength(strength)
223                    .with_reason(format!("RSI crossed below {:.0}", self.overbought));
224            }
225        }
226
227        Signal::hold()
228    }
229}
230
/// MACD Signal Strategy
///
/// Enters long when the MACD line crosses above its signal line and closes
/// the long on the opposite crossover; bullish crossovers also close an open
/// short. Bearish crossovers emit short signals (execution gated by
/// [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig)).
#[derive(Debug, Clone)]
pub struct MacdSignal {
    /// Fast EMA period
    pub fast: usize,
    /// Slow EMA period
    pub slow: usize,
    /// Signal line period
    pub signal: usize,
}

impl MacdSignal {
    /// Create a MACD crossover strategy from its fast, slow and signal periods.
    pub fn new(fast: usize, slow: usize, signal: usize) -> Self {
        Self { fast, slow, signal }
    }
}

impl Default for MacdSignal {
    /// The conventional 12/26/9 parameterisation.
    fn default() -> Self {
        Self {
            fast: 12,
            slow: 26,
            signal: 9,
        }
    }
}
259
260impl Strategy for MacdSignal {
261    fn name(&self) -> &str {
262        "MACD Signal"
263    }
264
265    fn required_indicators(&self) -> Vec<(String, Indicator)> {
266        vec![(
267            "macd".to_string(),
268            Indicator::Macd {
269                fast: self.fast,
270                slow: self.slow,
271                signal: self.signal,
272            },
273        )]
274    }
275
276    fn warmup_period(&self) -> usize {
277        self.slow + self.signal
278    }
279
280    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
281        let candle = ctx.current_candle();
282
283        let line_key = format!("macd_line_{}_{}_{}", self.fast, self.slow, self.signal);
284        let sig_key = format!("macd_signal_{}_{}_{}", self.fast, self.slow, self.signal);
285
286        // MACD line and signal line are stored separately by the engine
287        // Bullish crossover
288        if ctx.crossed_above(&line_key, &sig_key) {
289            if ctx.is_short() {
290                return Signal::exit(candle.timestamp, candle.close)
291                    .with_reason("MACD bullish crossover - close short");
292            }
293            if !ctx.has_position() {
294                return Signal::long(candle.timestamp, candle.close)
295                    .with_reason("MACD bullish crossover");
296            }
297        }
298
299        // Bearish crossover
300        if ctx.crossed_below(&line_key, &sig_key) {
301            if ctx.is_long() {
302                return Signal::exit(candle.timestamp, candle.close)
303                    .with_reason("MACD bearish crossover - close long");
304            }
305            if !ctx.has_position() {
306                return Signal::short(candle.timestamp, candle.close)
307                    .with_reason("MACD bearish crossover");
308            }
309        }
310
311        Signal::hold()
312    }
313}
314
/// Bollinger Bands Mean Reversion Strategy
///
/// Goes long when the close is at or below the lower band (oversold). The long
/// is exited once the close reaches the middle band, or the upper band when
/// `exit_at_middle` is `false`.
///
/// Emits short signals when the close is at or above the upper band (execution
/// gated by [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig)).
/// Shorts exit symmetrically at the middle band, or at the lower band when
/// `exit_at_middle` is `false`.
///
/// # Signal Strength
///
/// All entry signals emit at default strength (`1.0`). Strength is **not** scaled
/// by how far price has penetrated through the band. This differs from
/// [`RsiReversal`], which grades strength by RSI extremity. If you are relying
/// on
/// [`BacktestConfig::min_signal_strength`](crate::backtesting::BacktestConfig)
/// to filter signals in a portfolio context, all Bollinger entries will pass the
/// threshold equally.
#[derive(Debug, Clone)]
pub struct BollingerMeanReversion {
    /// SMA period for middle band
    pub period: usize,
    /// Standard deviation multiplier
    pub std_dev: f64,
    /// Exit at middle band (true) or upper/lower band (false)
    pub exit_at_middle: bool,
}

impl BollingerMeanReversion {
    /// Create a new Bollinger mean reversion strategy (exits at the middle band).
    pub fn new(period: usize, std_dev: f64) -> Self {
        Self {
            period,
            std_dev,
            exit_at_middle: true,
        }
    }

    /// Set exit target (middle band or opposite band)
    pub fn exit_at_middle(mut self, at_middle: bool) -> Self {
        self.exit_at_middle = at_middle;
        self
    }
}

impl Default for BollingerMeanReversion {
    /// 20-period bands at 2 standard deviations.
    fn default() -> Self {
        Self::new(20, 2.0)
    }
}
361
362impl Strategy for BollingerMeanReversion {
363    fn name(&self) -> &str {
364        "Bollinger Mean Reversion"
365    }
366
367    fn required_indicators(&self) -> Vec<(String, Indicator)> {
368        vec![(
369            "bollinger".to_string(),
370            Indicator::Bollinger {
371                period: self.period,
372                std_dev: self.std_dev,
373            },
374        )]
375    }
376
377    fn warmup_period(&self) -> usize {
378        self.period
379    }
380
381    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
382        let candle = ctx.current_candle();
383        let close = candle.close;
384
385        let lower = ctx.indicator(&format!("bollinger_lower_{}_{}", self.period, self.std_dev));
386        let middle = ctx.indicator(&format!(
387            "bollinger_middle_{}_{}",
388            self.period, self.std_dev
389        ));
390        let upper = ctx.indicator(&format!("bollinger_upper_{}_{}", self.period, self.std_dev));
391
392        let (Some(lower_val), Some(middle_val), Some(upper_val)) = (lower, middle, upper) else {
393            return Signal::hold();
394        };
395
396        // Long entry: price at or below lower band
397        if close <= lower_val && !ctx.has_position() {
398            return Signal::long(candle.timestamp, close)
399                .with_reason("Price at lower Bollinger Band");
400        }
401
402        // Long exit
403        if ctx.is_long() {
404            let exit_level = if self.exit_at_middle {
405                middle_val
406            } else {
407                upper_val
408            };
409            if close >= exit_level {
410                return Signal::exit(candle.timestamp, close).with_reason(format!(
411                    "Price reached {} Bollinger Band",
412                    if self.exit_at_middle {
413                        "middle"
414                    } else {
415                        "upper"
416                    }
417                ));
418            }
419        }
420
421        // Short entry: price at or above upper band
422        if close >= upper_val && !ctx.has_position() {
423            return Signal::short(candle.timestamp, close)
424                .with_reason("Price at upper Bollinger Band");
425        }
426
427        // Short exit
428        if ctx.is_short() {
429            let exit_level = if self.exit_at_middle {
430                middle_val
431            } else {
432                lower_val
433            };
434            if close <= exit_level {
435                return Signal::exit(candle.timestamp, close).with_reason(format!(
436                    "Price reached {} Bollinger Band",
437                    if self.exit_at_middle {
438                        "middle"
439                    } else {
440                        "lower"
441                    }
442                ));
443            }
444        }
445
446        Signal::hold()
447    }
448}
449
/// SuperTrend Following Strategy
///
/// Opens a long when the SuperTrend direction flips to an uptrend, closing any
/// open short on that flip. A flip to a downtrend closes a long, or emits a
/// short signal when flat (execution gated by
/// [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig)).
#[derive(Debug, Clone)]
pub struct SuperTrendFollow {
    /// ATR period
    pub period: usize,
    /// ATR multiplier
    pub multiplier: f64,
}

impl SuperTrendFollow {
    /// Create a SuperTrend follower from its ATR period and multiplier.
    pub fn new(period: usize, multiplier: f64) -> Self {
        Self { period, multiplier }
    }
}

impl Default for SuperTrendFollow {
    /// 10-period ATR with a 3x multiplier.
    fn default() -> Self {
        Self {
            period: 10,
            multiplier: 3.0,
        }
    }
}
475
476impl Strategy for SuperTrendFollow {
477    fn name(&self) -> &str {
478        "SuperTrend Follow"
479    }
480
481    fn required_indicators(&self) -> Vec<(String, Indicator)> {
482        vec![(
483            "supertrend".to_string(),
484            Indicator::Supertrend {
485                period: self.period,
486                multiplier: self.multiplier,
487            },
488        )]
489    }
490
491    fn warmup_period(&self) -> usize {
492        self.period + 1
493    }
494
495    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
496        let candle = ctx.current_candle();
497
498        // SuperTrend uptrend stored as 1.0, downtrend as 0.0
499        let uptrend_key = format!("supertrend_uptrend_{}_{}", self.period, self.multiplier);
500        let trend_now = ctx.indicator(&uptrend_key);
501        let trend_prev = ctx.indicator_prev(&uptrend_key);
502
503        let (Some(now), Some(prev)) = (trend_now, trend_prev) else {
504            return Signal::hold();
505        };
506
507        let is_uptrend = now > 0.5;
508        let was_uptrend = prev > 0.5;
509
510        // Trend changed to bullish
511        if is_uptrend && !was_uptrend {
512            if ctx.is_short() {
513                return Signal::exit(candle.timestamp, candle.close)
514                    .with_reason("SuperTrend turned bullish - close short");
515            }
516            if !ctx.has_position() {
517                return Signal::long(candle.timestamp, candle.close)
518                    .with_reason("SuperTrend turned bullish");
519            }
520        }
521
522        // Trend changed to bearish
523        if !is_uptrend && was_uptrend {
524            if ctx.is_long() {
525                return Signal::exit(candle.timestamp, candle.close)
526                    .with_reason("SuperTrend turned bearish - close long");
527            }
528            if !ctx.has_position() {
529                return Signal::short(candle.timestamp, candle.close)
530                    .with_reason("SuperTrend turned bearish");
531            }
532        }
533
534        Signal::hold()
535    }
536}
537
/// Donchian Channel Breakout Strategy
///
/// Buys when the close breaks out above the upper channel (a new high for the
/// window) and exits longs when the close breaks down below the lower channel
/// (a new low). Downward breakouts also emit short signals (execution gated by
/// [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig)).
#[derive(Debug, Clone)]
pub struct DonchianBreakout {
    /// Channel period
    pub period: usize,
    /// Exit at the middle channel (`true`) or at the opposite channel (`false`)
    pub exit_at_middle: bool,
}

impl DonchianBreakout {
    /// Create a Donchian breakout strategy that exits at the middle channel.
    pub fn new(period: usize) -> Self {
        Self {
            period,
            exit_at_middle: true,
        }
    }

    /// Choose whether positions exit at the middle channel.
    pub fn exit_at_middle(self, at_middle: bool) -> Self {
        Self {
            exit_at_middle: at_middle,
            ..self
        }
    }
}

impl Default for DonchianBreakout {
    /// A 20-period channel.
    fn default() -> Self {
        Self::new(20)
    }
}
573
574impl Strategy for DonchianBreakout {
575    fn name(&self) -> &str {
576        "Donchian Breakout"
577    }
578
579    fn required_indicators(&self) -> Vec<(String, Indicator)> {
580        vec![(
581            "donchian".to_string(),
582            Indicator::DonchianChannels(self.period),
583        )]
584    }
585
586    fn warmup_period(&self) -> usize {
587        self.period
588    }
589
590    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
591        let candle = ctx.current_candle();
592        let close = candle.close;
593
594        let upper_key = format!("donchian_upper_{}", self.period);
595        let middle_key = format!("donchian_middle_{}", self.period);
596        let lower_key = format!("donchian_lower_{}", self.period);
597        let upper = ctx.indicator(&upper_key);
598        let middle = ctx.indicator(&middle_key);
599        let lower = ctx.indicator(&lower_key);
600        let prev_upper = ctx.indicator_prev(&upper_key);
601        let prev_lower = ctx.indicator_prev(&lower_key);
602
603        let (Some(_upper_val), Some(middle_val), Some(_lower_val)) = (upper, middle, lower) else {
604            return Signal::hold();
605        };
606
607        // Breakout above the *previous* bar's upper channel level → go long.
608        // Using the lagged level rather than the current bar's channel prevents
609        // look-ahead bias: the current bar's Donchian high is computed using the
610        // close of that same bar, so comparing `close > current_upper` would
611        // trivially never trigger (the close can equal but not exceed the max
612        // of the window it belongs to).  The lagged level is the natural
613        // reference point for a confirmed breakout signal.
614        if let Some(prev_up) = prev_upper
615            && close > prev_up
616            && !ctx.has_position()
617        {
618            return Signal::long(candle.timestamp, close)
619                .with_reason("Donchian upper channel breakout");
620        }
621
622        // Breakdown below the *previous* bar's lower channel level (same
623        // lagged-reference rationale as the upper channel breakout above).
624        if let Some(prev_low) = prev_lower
625            && close < prev_low
626        {
627            if ctx.is_long() {
628                return Signal::exit(candle.timestamp, close)
629                    .with_reason("Donchian lower channel breakdown - close long");
630            }
631            if !ctx.has_position() {
632                return Signal::short(candle.timestamp, close)
633                    .with_reason("Donchian lower channel breakdown");
634            }
635        }
636
637        // Exit long at middle
638        if ctx.is_long() && self.exit_at_middle && close <= middle_val {
639            return Signal::exit(candle.timestamp, close)
640                .with_reason("Price reached Donchian middle channel");
641        }
642
643        // Exit short at middle
644        if ctx.is_short() && self.exit_at_middle && close >= middle_val {
645            return Signal::exit(candle.timestamp, close)
646                .with_reason("Price reached Donchian middle channel");
647        }
648
649        Signal::hold()
650    }
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sma_crossover_default() {
        let s = SmaCrossover::default();
        assert_eq!(s.fast_period, 10);
        assert_eq!(s.slow_period, 20);
    }

    #[test]
    fn test_sma_crossover_custom() {
        let s = SmaCrossover::new(5, 15);
        assert_eq!(s.fast_period, 5);
        assert_eq!(s.slow_period, 15);
    }

    #[test]
    fn test_rsi_default() {
        let s = RsiReversal::default();
        assert_eq!(s.period, 14);
        assert!((s.oversold - 30.0).abs() < 0.01);
        assert!((s.overbought - 70.0).abs() < 0.01);
    }

    #[test]
    fn test_rsi_with_thresholds() {
        let s = RsiReversal::new(10).with_thresholds(25.0, 75.0);
        assert_eq!(s.period, 10);
        assert!((s.oversold - 25.0).abs() < 0.01);
        assert!((s.overbought - 75.0).abs() < 0.01);
    }

    #[test]
    fn test_macd_default() {
        let s = MacdSignal::default();
        assert_eq!(s.fast, 12);
        assert_eq!(s.slow, 26);
        assert_eq!(s.signal, 9);
    }

    #[test]
    fn test_bollinger_default() {
        let s = BollingerMeanReversion::default();
        assert_eq!(s.period, 20);
        assert!((s.std_dev - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_supertrend_default() {
        let s = SuperTrendFollow::default();
        assert_eq!(s.period, 10);
        assert!((s.multiplier - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_donchian_default() {
        let s = DonchianBreakout::default();
        assert_eq!(s.period, 20);
        assert!(s.exit_at_middle);
    }

    #[test]
    fn test_strategy_names() {
        assert_eq!(SmaCrossover::default().name(), "SMA Crossover");
        assert_eq!(RsiReversal::default().name(), "RSI Reversal");
        assert_eq!(MacdSignal::default().name(), "MACD Signal");
        assert_eq!(
            BollingerMeanReversion::default().name(),
            "Bollinger Mean Reversion"
        );
        assert_eq!(SuperTrendFollow::default().name(), "SuperTrend Follow");
        assert_eq!(DonchianBreakout::default().name(), "Donchian Breakout");
    }

    #[test]
    fn test_required_indicators() {
        let sma = SmaCrossover::new(5, 10);
        let indicators = sma.required_indicators();
        assert_eq!(indicators.len(), 2);
        assert_eq!(indicators[0].0, "sma_5");
        assert_eq!(indicators[1].0, "sma_10");

        let rsi = RsiReversal::new(14);
        let indicators = rsi.required_indicators();
        assert_eq!(indicators.len(), 1);
        assert_eq!(indicators[0].0, "rsi_14");
    }

    // Composite indicators are registered under a single base key; the engine
    // is expected to expand them into the per-line keys read in `on_candle`.
    #[test]
    fn test_composite_indicator_keys() {
        let macd = MacdSignal::default().required_indicators();
        assert_eq!(macd.len(), 1);
        assert_eq!(macd[0].0, "macd");

        let bollinger = BollingerMeanReversion::default().required_indicators();
        assert_eq!(bollinger.len(), 1);
        assert_eq!(bollinger[0].0, "bollinger");

        let supertrend = SuperTrendFollow::default().required_indicators();
        assert_eq!(supertrend.len(), 1);
        assert_eq!(supertrend[0].0, "supertrend");

        let donchian = DonchianBreakout::default().required_indicators();
        assert_eq!(donchian.len(), 1);
        assert_eq!(donchian[0].0, "donchian");
    }

    #[test]
    fn test_warmup_periods() {
        assert_eq!(SmaCrossover::new(5, 10).warmup_period(), 11);
        // The longer period drives the warmup regardless of argument order.
        assert_eq!(SmaCrossover::new(30, 10).warmup_period(), 31);
        assert_eq!(RsiReversal::new(14).warmup_period(), 15);
        assert_eq!(MacdSignal::default().warmup_period(), 26 + 9);
        assert_eq!(BollingerMeanReversion::default().warmup_period(), 20);
        assert_eq!(SuperTrendFollow::default().warmup_period(), 11);
        assert_eq!(DonchianBreakout::default().warmup_period(), 20);
    }

    #[test]
    fn test_exit_at_middle_builders() {
        let bollinger = BollingerMeanReversion::new(20, 2.0);
        assert!(bollinger.exit_at_middle);
        let bollinger = bollinger.exit_at_middle(false);
        assert!(!bollinger.exit_at_middle);
        assert_eq!(bollinger.period, 20);

        let donchian = DonchianBreakout::new(55).exit_at_middle(false);
        assert_eq!(donchian.period, 55);
        assert!(!donchian.exit_at_middle);
    }
}