Skip to main content

finance_query/backtesting/strategy/
prebuilt.rs

1//! Pre-built trading strategies.
2//!
3//! Ready-to-use strategy implementations that can be used directly with the backtest engine.
4//! Each strategy implements the [`Strategy`] trait and can be customized via builder methods.
5//!
6//! Short signals are always emitted when the condition is met. Whether they are
7//! *executed* is controlled solely by [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig).
8//!
9//! # Available Strategies
10//!
11//! | Strategy | Description |
12//! |----------|-------------|
13//! | [`SmaCrossover`] | Dual SMA crossover (trend following) |
14//! | [`RsiReversal`] | RSI mean reversion |
15//! | [`MacdSignal`] | MACD line crossover |
16//! | [`BollingerMeanReversion`] | Bollinger Bands mean reversion |
17//! | [`SuperTrendFollow`] | SuperTrend trend following |
18//! | [`DonchianBreakout`] | Donchian channel breakout |
19//!
20//! # Example
21//!
22//! ```ignore
23//! use finance_query::backtesting::{SmaCrossover, BacktestConfig};
24//!
25//! let strategy = SmaCrossover::new(10, 20);
26//! let config = BacktestConfig::builder().allow_short(true).build().unwrap();
27//! ```
28
29use std::collections::HashMap;
30
31use crate::indicators::Indicator;
32
33use super::{Signal, Strategy, StrategyContext};
34use crate::backtesting::signal::SignalStrength;
35
36// ── Per-key pointer cache ─────────────────────────────────────────────────────
37
/// Single-entry cache holding a raw pointer to one pre-computed indicator series.
///
/// [`Strategy::setup`] fills the slot once, before the simulation loop starts;
/// [`Strategy::on_candle`] then reads through it on every bar instead of doing a
/// `HashMap` lookup. Cloning yields an *empty* slot, so a cloned strategy
/// carries no pointer into another run's data and must be set up again (which
/// `BacktestEngine::run` does).
///
/// # Safety invariant
/// The stored pointer refers into the `indicators` map owned by the enclosing
/// `simulate()` frame. That map is not mutated during the loop and outlives
/// every `on_candle` call made from it, so the pointer is valid for exactly
/// that window.
///
/// NOTE(review): nothing in this type clears the slot when a run finishes, so
/// an instance reused *without* a fresh `setup()` call would dereference a
/// pointer into a dropped map — confirm no caller does that.
#[derive(Debug, Default)]
struct IndicatorSlot(Option<*const Vec<Option<f64>>>);

// SAFETY: the pointer is dereferenced only from within the engine's
// single-threaded simulation loop, and the slot is never handed to another
// thread while its pointee might already be gone.
unsafe impl Send for IndicatorSlot {}
unsafe impl Sync for IndicatorSlot {}

impl Clone for IndicatorSlot {
    /// Produces an unset slot; the clone must be populated by `setup()` before
    /// it yields anything.
    fn clone(&self) -> Self {
        Self::default()
    }
}

impl IndicatorSlot {
    /// Remembers the address of `series` for later reads via [`Self::get`].
    fn set(&mut self, series: &Vec<Option<f64>>) {
        let raw: *const Vec<Option<f64>> = series;
        self.0 = Some(raw);
    }

    /// Borrows the cached series, or returns `None` if the slot was never set.
    ///
    /// # Safety
    /// Call only while the simulation loop whose `setup()` filled this slot is
    /// running, i.e. while the source `HashMap` is still alive and unmodified.
    #[inline]
    unsafe fn get(&self) -> Option<&Vec<Option<f64>>> {
        let raw = self.0?;
        // SAFETY: validity of `raw` is guaranteed by the caller (see above).
        Some(unsafe { &*raw })
    }
}
79
80/// SMA Crossover Strategy
81///
82/// Goes long when fast SMA crosses above slow SMA.
83/// Exits when fast SMA crosses below slow SMA.
84/// Emits short signals on bearish crossovers (execution gated by
85/// [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig)).
86#[derive(Debug, Clone)]
87pub struct SmaCrossover {
88    /// Fast SMA period
89    pub fast_period: usize,
90    /// Slow SMA period
91    pub slow_period: usize,
92    fast_key: String,
93    slow_key: String,
94    fast_slot: IndicatorSlot,
95    slow_slot: IndicatorSlot,
96}
97
98impl SmaCrossover {
99    /// Create a new SMA crossover strategy
100    pub fn new(fast_period: usize, slow_period: usize) -> Self {
101        Self {
102            fast_period,
103            slow_period,
104            fast_key: format!("sma_{fast_period}"),
105            slow_key: format!("sma_{slow_period}"),
106            fast_slot: IndicatorSlot::default(),
107            slow_slot: IndicatorSlot::default(),
108        }
109    }
110}
111
112impl Default for SmaCrossover {
113    fn default() -> Self {
114        Self::new(10, 20)
115    }
116}
117
118impl Strategy for SmaCrossover {
119    fn name(&self) -> &str {
120        "SMA Crossover"
121    }
122
123    fn required_indicators(&self) -> Vec<(String, Indicator)> {
124        vec![
125            (self.fast_key.clone(), Indicator::Sma(self.fast_period)),
126            (self.slow_key.clone(), Indicator::Sma(self.slow_period)),
127        ]
128    }
129
130    fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
131        if let Some(v) = indicators.get(&self.fast_key) {
132            self.fast_slot.set(v);
133        }
134        if let Some(v) = indicators.get(&self.slow_key) {
135            self.slow_slot.set(v);
136        }
137    }
138
139    fn warmup_period(&self) -> usize {
140        self.slow_period.max(self.fast_period) + 1
141    }
142
143    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
144        let candle = ctx.current_candle();
145        let i = ctx.index;
146        if i == 0 {
147            return Signal::hold();
148        }
149
150        // Use cached pointer (0 HashMap lookups); fall back to map lookup if
151        // setup() was not called (e.g., strategy used outside the engine).
152        // SAFETY: setup() was called from simulate() with the indicators map
153        // that is alive and unmodified for the duration of the loop.
154        let fast_vals =
155            unsafe { self.fast_slot.get() }.or_else(|| ctx.indicators.get(&self.fast_key));
156        let slow_vals =
157            unsafe { self.slow_slot.get() }.or_else(|| ctx.indicators.get(&self.slow_key));
158        let (Some(fast_vals), Some(slow_vals)) = (fast_vals, slow_vals) else {
159            return Signal::hold();
160        };
161
162        let get = |vals: &Vec<Option<f64>>, idx: usize| vals.get(idx).and_then(|&v| v);
163        let (Some(fn_), Some(sn), Some(fp), Some(sp)) = (
164            get(fast_vals, i),
165            get(slow_vals, i),
166            get(fast_vals, i - 1),
167            get(slow_vals, i - 1),
168        ) else {
169            return Signal::hold();
170        };
171
172        // Bullish crossover: fast crosses above slow
173        if fp < sp && fn_ > sn {
174            if ctx.is_short() {
175                return Signal::exit(candle.timestamp, candle.close)
176                    .with_reason("SMA bullish crossover - close short");
177            }
178            if !ctx.has_position() {
179                return Signal::long(candle.timestamp, candle.close)
180                    .with_reason("SMA bullish crossover");
181            }
182        }
183
184        // Bearish crossover: fast crosses below slow
185        if fp > sp && fn_ < sn {
186            if ctx.is_long() {
187                return Signal::exit(candle.timestamp, candle.close)
188                    .with_reason("SMA bearish crossover - close long");
189            }
190            if !ctx.has_position() {
191                return Signal::short(candle.timestamp, candle.close)
192                    .with_reason("SMA bearish crossover");
193            }
194        }
195
196        Signal::hold()
197    }
198}
199
200/// RSI Reversal Strategy
201///
202/// Goes long when RSI crosses above oversold level.
203/// Exits when RSI reaches overbought level.
204/// Emits short signals when RSI crosses below overbought (execution gated by
205/// [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig)).
206#[derive(Debug, Clone)]
207pub struct RsiReversal {
208    /// RSI period
209    pub period: usize,
210    /// Oversold threshold (default 30)
211    pub oversold: f64,
212    /// Overbought threshold (default 70)
213    pub overbought: f64,
214    rsi_key: String,
215    rsi_slot: IndicatorSlot,
216}
217
218impl RsiReversal {
219    /// Create a new RSI reversal strategy
220    pub fn new(period: usize) -> Self {
221        Self {
222            period,
223            oversold: 30.0,
224            overbought: 70.0,
225            rsi_key: format!("rsi_{period}"),
226            rsi_slot: IndicatorSlot::default(),
227        }
228    }
229
230    /// Set custom oversold/overbought thresholds
231    pub fn with_thresholds(mut self, oversold: f64, overbought: f64) -> Self {
232        self.oversold = oversold;
233        self.overbought = overbought;
234        self
235    }
236}
237
238impl Default for RsiReversal {
239    fn default() -> Self {
240        Self::new(14)
241    }
242}
243
244impl Strategy for RsiReversal {
245    fn name(&self) -> &str {
246        "RSI Reversal"
247    }
248
249    fn required_indicators(&self) -> Vec<(String, Indicator)> {
250        vec![(self.rsi_key.clone(), Indicator::Rsi(self.period))]
251    }
252
253    fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
254        if let Some(v) = indicators.get(&self.rsi_key) {
255            self.rsi_slot.set(v);
256        }
257    }
258
259    fn warmup_period(&self) -> usize {
260        self.period + 1
261    }
262
263    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
264        let candle = ctx.current_candle();
265        let i = ctx.index;
266
267        // SAFETY: see SmaCrossover::on_candle.
268        let rsi_vals = unsafe { self.rsi_slot.get() }.or_else(|| ctx.indicators.get(&self.rsi_key));
269        let Some(rsi_vals) = rsi_vals else {
270            return Signal::hold();
271        };
272        let get = |idx: usize| rsi_vals.get(idx).and_then(|&v| v);
273        let Some(rsi_val) = get(i) else {
274            return Signal::hold();
275        };
276        let rsi_prev = if i > 0 { get(i - 1) } else { None };
277
278        // Calculate signal strength based on RSI extremity
279        let strength = if !(20.0..=80.0).contains(&rsi_val) {
280            SignalStrength::strong()
281        } else if !(25.0..=75.0).contains(&rsi_val) {
282            SignalStrength::medium()
283        } else {
284            SignalStrength::weak()
285        };
286
287        // Bullish: RSI crosses above oversold
288        let crossed_above_oversold =
289            rsi_prev.is_some_and(|p| p <= self.oversold) && rsi_val > self.oversold;
290        if crossed_above_oversold {
291            if ctx.is_short() {
292                return Signal::exit(candle.timestamp, candle.close)
293                    .with_strength(strength)
294                    .with_reason(format!(
295                        "RSI crossed above {:.0} - close short",
296                        self.oversold
297                    ));
298            }
299            if !ctx.has_position() {
300                return Signal::long(candle.timestamp, candle.close)
301                    .with_strength(strength)
302                    .with_reason(format!("RSI crossed above {:.0}", self.oversold));
303            }
304        }
305
306        // Bearish: RSI crosses below overbought
307        let crossed_below_overbought =
308            rsi_prev.is_some_and(|p| p >= self.overbought) && rsi_val < self.overbought;
309        if crossed_below_overbought {
310            if ctx.is_long() {
311                return Signal::exit(candle.timestamp, candle.close)
312                    .with_strength(strength)
313                    .with_reason(format!(
314                        "RSI crossed below {:.0} - close long",
315                        self.overbought
316                    ));
317            }
318            if !ctx.has_position() {
319                return Signal::short(candle.timestamp, candle.close)
320                    .with_strength(strength)
321                    .with_reason(format!("RSI crossed below {:.0}", self.overbought));
322            }
323        }
324
325        Signal::hold()
326    }
327}
328
329/// MACD Signal Strategy
330///
331/// Goes long when MACD line crosses above signal line.
332/// Exits when MACD line crosses below signal line.
333/// Emits short signals on bearish crossovers (execution gated by
334/// [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig)).
335#[derive(Debug, Clone)]
336pub struct MacdSignal {
337    /// Fast EMA period
338    pub fast: usize,
339    /// Slow EMA period
340    pub slow: usize,
341    /// Signal line period
342    pub signal: usize,
343    line_key: String,
344    sig_key: String,
345    line_slot: IndicatorSlot,
346    sig_slot: IndicatorSlot,
347}
348
349impl MacdSignal {
350    /// Create a new MACD signal strategy
351    pub fn new(fast: usize, slow: usize, signal: usize) -> Self {
352        Self {
353            fast,
354            slow,
355            signal,
356            line_key: format!("macd_line_{fast}_{slow}_{signal}"),
357            sig_key: format!("macd_signal_{fast}_{slow}_{signal}"),
358            line_slot: IndicatorSlot::default(),
359            sig_slot: IndicatorSlot::default(),
360        }
361    }
362}
363
364impl Default for MacdSignal {
365    fn default() -> Self {
366        Self::new(12, 26, 9)
367    }
368}
369
370impl Strategy for MacdSignal {
371    fn name(&self) -> &str {
372        "MACD Signal"
373    }
374
375    fn required_indicators(&self) -> Vec<(String, Indicator)> {
376        vec![(
377            "macd".to_string(),
378            Indicator::Macd {
379                fast: self.fast,
380                slow: self.slow,
381                signal: self.signal,
382            },
383        )]
384    }
385
386    fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
387        if let Some(v) = indicators.get(&self.line_key) {
388            self.line_slot.set(v);
389        }
390        if let Some(v) = indicators.get(&self.sig_key) {
391            self.sig_slot.set(v);
392        }
393    }
394
395    fn warmup_period(&self) -> usize {
396        self.slow + self.signal
397    }
398
399    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
400        let candle = ctx.current_candle();
401        let i = ctx.index;
402        if i == 0 {
403            return Signal::hold();
404        }
405
406        // SAFETY: see SmaCrossover::on_candle.
407        let line_vals =
408            unsafe { self.line_slot.get() }.or_else(|| ctx.indicators.get(&self.line_key));
409        let sig_vals = unsafe { self.sig_slot.get() }.or_else(|| ctx.indicators.get(&self.sig_key));
410        let (Some(line_vals), Some(sig_vals)) = (line_vals, sig_vals) else {
411            return Signal::hold();
412        };
413
414        let get = |vals: &Vec<Option<f64>>, idx: usize| vals.get(idx).and_then(|&v| v);
415        let (Some(ln), Some(sn), Some(lp), Some(sp)) = (
416            get(line_vals, i),
417            get(sig_vals, i),
418            get(line_vals, i - 1),
419            get(sig_vals, i - 1),
420        ) else {
421            return Signal::hold();
422        };
423
424        // MACD line and signal line are stored separately by the engine
425        // Bullish crossover
426        if lp < sp && ln > sn {
427            if ctx.is_short() {
428                return Signal::exit(candle.timestamp, candle.close)
429                    .with_reason("MACD bullish crossover - close short");
430            }
431            if !ctx.has_position() {
432                return Signal::long(candle.timestamp, candle.close)
433                    .with_reason("MACD bullish crossover");
434            }
435        }
436
437        // Bearish crossover
438        if lp > sp && ln < sn {
439            if ctx.is_long() {
440                return Signal::exit(candle.timestamp, candle.close)
441                    .with_reason("MACD bearish crossover - close long");
442            }
443            if !ctx.has_position() {
444                return Signal::short(candle.timestamp, candle.close)
445                    .with_reason("MACD bearish crossover");
446            }
447        }
448
449        Signal::hold()
450    }
451}
452
453/// Bollinger Bands Mean Reversion Strategy
454///
455/// Goes long when price touches lower band (oversold).
456/// Exits when price reaches middle or upper band.
457/// Emits short signals when price touches upper band (execution gated by
458/// [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig)).
459///
460/// # Signal Strength
461///
462/// All entry signals emit at default strength (`1.0`). Strength is **not** scaled
463/// by how far price has penetrated through the band. This differs from
464/// [`RsiReversal`], which grades strength by RSI extremity. If you are relying
465/// on [`BacktestConfig::min_signal_strength`] to filter signals in a portfolio
466/// context, all Bollinger entries will pass the threshold equally.
467#[derive(Debug, Clone)]
468pub struct BollingerMeanReversion {
469    /// SMA period for middle band
470    pub period: usize,
471    /// Standard deviation multiplier
472    pub std_dev: f64,
473    /// Exit at middle band (true) or upper/lower band (false)
474    pub exit_at_middle: bool,
475    lower_key: String,
476    middle_key: String,
477    upper_key: String,
478    lower_slot: IndicatorSlot,
479    middle_slot: IndicatorSlot,
480    upper_slot: IndicatorSlot,
481}
482
483impl BollingerMeanReversion {
484    /// Create a new Bollinger mean reversion strategy
485    pub fn new(period: usize, std_dev: f64) -> Self {
486        Self {
487            period,
488            std_dev,
489            exit_at_middle: true,
490            lower_key: format!("bollinger_lower_{period}_{std_dev}"),
491            middle_key: format!("bollinger_middle_{period}_{std_dev}"),
492            upper_key: format!("bollinger_upper_{period}_{std_dev}"),
493            lower_slot: IndicatorSlot::default(),
494            middle_slot: IndicatorSlot::default(),
495            upper_slot: IndicatorSlot::default(),
496        }
497    }
498
499    /// Set exit target (middle band or opposite band)
500    pub fn exit_at_middle(mut self, at_middle: bool) -> Self {
501        self.exit_at_middle = at_middle;
502        self
503    }
504}
505
506impl Default for BollingerMeanReversion {
507    fn default() -> Self {
508        Self::new(20, 2.0)
509    }
510}
511
512impl Strategy for BollingerMeanReversion {
513    fn name(&self) -> &str {
514        "Bollinger Mean Reversion"
515    }
516
517    fn required_indicators(&self) -> Vec<(String, Indicator)> {
518        vec![(
519            "bollinger".to_string(),
520            Indicator::Bollinger {
521                period: self.period,
522                std_dev: self.std_dev,
523            },
524        )]
525    }
526
527    fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
528        if let Some(v) = indicators.get(&self.lower_key) {
529            self.lower_slot.set(v);
530        }
531        if let Some(v) = indicators.get(&self.middle_key) {
532            self.middle_slot.set(v);
533        }
534        if let Some(v) = indicators.get(&self.upper_key) {
535            self.upper_slot.set(v);
536        }
537    }
538
539    fn warmup_period(&self) -> usize {
540        self.period
541    }
542
543    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
544        let candle = ctx.current_candle();
545        let close = candle.close;
546        let i = ctx.index;
547
548        // SAFETY: see SmaCrossover::on_candle.
549        let lower_vals =
550            unsafe { self.lower_slot.get() }.or_else(|| ctx.indicators.get(&self.lower_key));
551        let middle_vals =
552            unsafe { self.middle_slot.get() }.or_else(|| ctx.indicators.get(&self.middle_key));
553        let upper_vals =
554            unsafe { self.upper_slot.get() }.or_else(|| ctx.indicators.get(&self.upper_key));
555        let (Some(lower_vals), Some(middle_vals), Some(upper_vals)) =
556            (lower_vals, middle_vals, upper_vals)
557        else {
558            return Signal::hold();
559        };
560
561        let get = |vals: &Vec<Option<f64>>, idx: usize| vals.get(idx).and_then(|&v| v);
562        let (Some(lower_val), Some(middle_val), Some(upper_val)) =
563            (get(lower_vals, i), get(middle_vals, i), get(upper_vals, i))
564        else {
565            return Signal::hold();
566        };
567
568        // Long entry: price at or below lower band
569        if close <= lower_val && !ctx.has_position() {
570            return Signal::long(candle.timestamp, close)
571                .with_reason("Price at lower Bollinger Band");
572        }
573
574        // Long exit
575        if ctx.is_long() {
576            let exit_level = if self.exit_at_middle {
577                middle_val
578            } else {
579                upper_val
580            };
581            if close >= exit_level {
582                return Signal::exit(candle.timestamp, close).with_reason(format!(
583                    "Price reached {} Bollinger Band",
584                    if self.exit_at_middle {
585                        "middle"
586                    } else {
587                        "upper"
588                    }
589                ));
590            }
591        }
592
593        // Short entry: price at or above upper band
594        if close >= upper_val && !ctx.has_position() {
595            return Signal::short(candle.timestamp, close)
596                .with_reason("Price at upper Bollinger Band");
597        }
598
599        // Short exit
600        if ctx.is_short() {
601            let exit_level = if self.exit_at_middle {
602                middle_val
603            } else {
604                lower_val
605            };
606            if close <= exit_level {
607                return Signal::exit(candle.timestamp, close).with_reason(format!(
608                    "Price reached {} Bollinger Band",
609                    if self.exit_at_middle {
610                        "middle"
611                    } else {
612                        "lower"
613                    }
614                ));
615            }
616        }
617
618        Signal::hold()
619    }
620}
621
622/// SuperTrend Following Strategy
623///
624/// Goes long when SuperTrend turns bullish (uptrend).
625/// Emits short signals when SuperTrend turns bearish (execution gated by
626/// [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig)).
627#[derive(Debug, Clone)]
628pub struct SuperTrendFollow {
629    /// ATR period
630    pub period: usize,
631    /// ATR multiplier
632    pub multiplier: f64,
633    uptrend_key: String,
634    uptrend_slot: IndicatorSlot,
635}
636
637impl SuperTrendFollow {
638    /// Create a new SuperTrend following strategy
639    pub fn new(period: usize, multiplier: f64) -> Self {
640        Self {
641            period,
642            multiplier,
643            uptrend_key: format!("supertrend_uptrend_{period}_{multiplier}"),
644            uptrend_slot: IndicatorSlot::default(),
645        }
646    }
647}
648
649impl Default for SuperTrendFollow {
650    fn default() -> Self {
651        Self::new(10, 3.0)
652    }
653}
654
655impl Strategy for SuperTrendFollow {
656    fn name(&self) -> &str {
657        "SuperTrend Follow"
658    }
659
660    fn required_indicators(&self) -> Vec<(String, Indicator)> {
661        vec![(
662            "supertrend".to_string(),
663            Indicator::Supertrend {
664                period: self.period,
665                multiplier: self.multiplier,
666            },
667        )]
668    }
669
670    fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
671        if let Some(v) = indicators.get(&self.uptrend_key) {
672            self.uptrend_slot.set(v);
673        }
674    }
675
676    fn warmup_period(&self) -> usize {
677        self.period + 1
678    }
679
680    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
681        let candle = ctx.current_candle();
682        let i = ctx.index;
683
684        // SAFETY: see SmaCrossover::on_candle.
685        let vals =
686            unsafe { self.uptrend_slot.get() }.or_else(|| ctx.indicators.get(&self.uptrend_key));
687        let Some(vals) = vals else {
688            return Signal::hold();
689        };
690        let get = |idx: usize| vals.get(idx).and_then(|&v| v);
691        let (Some(now), Some(prev)) = (get(i), if i > 0 { get(i - 1) } else { None }) else {
692            return Signal::hold();
693        };
694
695        let is_uptrend = now > 0.5;
696        let was_uptrend = prev > 0.5;
697
698        // Trend changed to bullish
699        if is_uptrend && !was_uptrend {
700            if ctx.is_short() {
701                return Signal::exit(candle.timestamp, candle.close)
702                    .with_reason("SuperTrend turned bullish - close short");
703            }
704            if !ctx.has_position() {
705                return Signal::long(candle.timestamp, candle.close)
706                    .with_reason("SuperTrend turned bullish");
707            }
708        }
709
710        // Trend changed to bearish
711        if !is_uptrend && was_uptrend {
712            if ctx.is_long() {
713                return Signal::exit(candle.timestamp, candle.close)
714                    .with_reason("SuperTrend turned bearish - close long");
715            }
716            if !ctx.has_position() {
717                return Signal::short(candle.timestamp, candle.close)
718                    .with_reason("SuperTrend turned bearish");
719            }
720        }
721
722        Signal::hold()
723    }
724}
725
726/// Donchian Channel Breakout Strategy
727///
728/// Goes long when price breaks above upper channel (new high).
729/// Exits when price breaks below lower channel (new low).
730/// Emits short signals on downward breakouts (execution gated by
731/// [`BacktestConfig::allow_short`](crate::backtesting::BacktestConfig)).
732#[derive(Debug, Clone)]
733pub struct DonchianBreakout {
734    /// Channel period
735    pub period: usize,
736    /// Use middle channel for exit (true) or opposite channel (false)
737    pub exit_at_middle: bool,
738    upper_key: String,
739    middle_key: String,
740    lower_key: String,
741    upper_slot: IndicatorSlot,
742    middle_slot: IndicatorSlot,
743    lower_slot: IndicatorSlot,
744}
745
746impl DonchianBreakout {
747    /// Create a new Donchian breakout strategy
748    pub fn new(period: usize) -> Self {
749        Self {
750            period,
751            exit_at_middle: true,
752            upper_key: format!("donchian_upper_{period}"),
753            middle_key: format!("donchian_middle_{period}"),
754            lower_key: format!("donchian_lower_{period}"),
755            upper_slot: IndicatorSlot::default(),
756            middle_slot: IndicatorSlot::default(),
757            lower_slot: IndicatorSlot::default(),
758        }
759    }
760
761    /// Set exit at middle channel
762    pub fn exit_at_middle(mut self, at_middle: bool) -> Self {
763        self.exit_at_middle = at_middle;
764        self
765    }
766}
767
768impl Default for DonchianBreakout {
769    fn default() -> Self {
770        Self::new(20)
771    }
772}
773
774impl Strategy for DonchianBreakout {
775    fn name(&self) -> &str {
776        "Donchian Breakout"
777    }
778
779    fn required_indicators(&self) -> Vec<(String, Indicator)> {
780        vec![(
781            "donchian".to_string(),
782            Indicator::DonchianChannels(self.period),
783        )]
784    }
785
786    fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
787        if let Some(v) = indicators.get(&self.upper_key) {
788            self.upper_slot.set(v);
789        }
790        if let Some(v) = indicators.get(&self.middle_key) {
791            self.middle_slot.set(v);
792        }
793        if let Some(v) = indicators.get(&self.lower_key) {
794            self.lower_slot.set(v);
795        }
796    }
797
798    fn warmup_period(&self) -> usize {
799        self.period
800    }
801
802    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
803        let candle = ctx.current_candle();
804        let close = candle.close;
805        let i = ctx.index;
806
807        // SAFETY: see SmaCrossover::on_candle.
808        let upper_vals =
809            unsafe { self.upper_slot.get() }.or_else(|| ctx.indicators.get(&self.upper_key));
810        let middle_vals =
811            unsafe { self.middle_slot.get() }.or_else(|| ctx.indicators.get(&self.middle_key));
812        let lower_vals =
813            unsafe { self.lower_slot.get() }.or_else(|| ctx.indicators.get(&self.lower_key));
814        let (Some(upper_vals), Some(middle_vals), Some(lower_vals)) =
815            (upper_vals, middle_vals, lower_vals)
816        else {
817            return Signal::hold();
818        };
819        let get = |vals: &Vec<Option<f64>>, idx: usize| vals.get(idx).and_then(|&v| v);
820        let (Some(_upper_val), Some(middle_val), Some(_lower_val)) =
821            (get(upper_vals, i), get(middle_vals, i), get(lower_vals, i))
822        else {
823            return Signal::hold();
824        };
825        let prev_upper = if i > 0 { get(upper_vals, i - 1) } else { None };
826        let prev_lower = if i > 0 { get(lower_vals, i - 1) } else { None };
827
828        // Breakout above the *previous* bar's upper channel level → go long.
829        // Using the lagged level rather than the current bar's channel prevents
830        // look-ahead bias: the current bar's Donchian high is computed using the
831        // close of that same bar, so comparing `close > current_upper` would
832        // trivially never trigger (the close can equal but not exceed the max
833        // of the window it belongs to).  The lagged level is the natural
834        // reference point for a confirmed breakout signal.
835        if let Some(prev_up) = prev_upper
836            && close > prev_up
837            && !ctx.has_position()
838        {
839            return Signal::long(candle.timestamp, close)
840                .with_reason("Donchian upper channel breakout");
841        }
842
843        // Breakdown below the *previous* bar's lower channel level (same
844        // lagged-reference rationale as the upper channel breakout above).
845        if let Some(prev_low) = prev_lower
846            && close < prev_low
847        {
848            if ctx.is_long() {
849                return Signal::exit(candle.timestamp, close)
850                    .with_reason("Donchian lower channel breakdown - close long");
851            }
852            if !ctx.has_position() {
853                return Signal::short(candle.timestamp, close)
854                    .with_reason("Donchian lower channel breakdown");
855            }
856        }
857
858        // Exit long at middle
859        if ctx.is_long() && self.exit_at_middle && close <= middle_val {
860            return Signal::exit(candle.timestamp, close)
861                .with_reason("Price reached Donchian middle channel");
862        }
863
864        // Exit short at middle
865        if ctx.is_short() && self.exit_at_middle && close >= middle_val {
866            return Signal::exit(candle.timestamp, close)
867                .with_reason("Price reached Donchian middle channel");
868        }
869
870        Signal::hold()
871    }
872}
873
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_sma_crossover_default() {
        let s = SmaCrossover::default();
        assert_eq!(s.fast_period, 10);
        assert_eq!(s.slow_period, 20);
    }

    #[test]
    fn test_sma_crossover_custom() {
        let s = SmaCrossover::new(5, 15);
        assert_eq!(s.fast_period, 5);
        assert_eq!(s.slow_period, 15);
    }

    #[test]
    fn test_rsi_default() {
        let s = RsiReversal::default();
        assert_eq!(s.period, 14);
        assert!((s.oversold - 30.0).abs() < 0.01);
        assert!((s.overbought - 70.0).abs() < 0.01);
    }

    #[test]
    fn test_rsi_with_thresholds() {
        let s = RsiReversal::new(10).with_thresholds(25.0, 75.0);
        assert_eq!(s.period, 10);
        assert!((s.oversold - 25.0).abs() < 0.01);
        assert!((s.overbought - 75.0).abs() < 0.01);
    }

    #[test]
    fn test_macd_default() {
        let s = MacdSignal::default();
        assert_eq!(s.fast, 12);
        assert_eq!(s.slow, 26);
        assert_eq!(s.signal, 9);
    }

    #[test]
    fn test_bollinger_default() {
        let s = BollingerMeanReversion::default();
        assert_eq!(s.period, 20);
        assert!((s.std_dev - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_supertrend_default() {
        let s = SuperTrendFollow::default();
        assert_eq!(s.period, 10);
        assert!((s.multiplier - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_donchian_default() {
        let s = DonchianBreakout::default();
        assert_eq!(s.period, 20);
        assert!(s.exit_at_middle);
    }

    #[test]
    fn test_strategy_names() {
        assert_eq!(SmaCrossover::default().name(), "SMA Crossover");
        assert_eq!(RsiReversal::default().name(), "RSI Reversal");
        assert_eq!(MacdSignal::default().name(), "MACD Signal");
        assert_eq!(
            BollingerMeanReversion::default().name(),
            "Bollinger Mean Reversion"
        );
        assert_eq!(SuperTrendFollow::default().name(), "SuperTrend Follow");
        assert_eq!(DonchianBreakout::default().name(), "Donchian Breakout");
    }

    #[test]
    fn test_required_indicators() {
        let sma = SmaCrossover::new(5, 10);
        let indicators = sma.required_indicators();
        assert_eq!(indicators.len(), 2);
        assert_eq!(indicators[0].0, "sma_5");
        assert_eq!(indicators[1].0, "sma_10");

        let rsi = RsiReversal::new(14);
        let indicators = rsi.required_indicators();
        assert_eq!(indicators.len(), 1);
        assert_eq!(indicators[0].0, "rsi_14");

        // Multi-output indicators are requested under a single umbrella key.
        assert_eq!(MacdSignal::default().required_indicators()[0].0, "macd");
        assert_eq!(
            BollingerMeanReversion::default().required_indicators()[0].0,
            "bollinger"
        );
        assert_eq!(
            SuperTrendFollow::default().required_indicators()[0].0,
            "supertrend"
        );
        assert_eq!(
            DonchianBreakout::default().required_indicators()[0].0,
            "donchian"
        );
    }

    #[test]
    fn test_warmup_periods() {
        assert_eq!(SmaCrossover::new(5, 10).warmup_period(), 11);
        // The longer period wins regardless of argument order.
        assert_eq!(SmaCrossover::new(30, 10).warmup_period(), 31);
        assert_eq!(RsiReversal::new(14).warmup_period(), 15);
        assert_eq!(MacdSignal::default().warmup_period(), 35);
        assert_eq!(BollingerMeanReversion::default().warmup_period(), 20);
        assert_eq!(SuperTrendFollow::default().warmup_period(), 11);
        assert_eq!(DonchianBreakout::default().warmup_period(), 20);
    }

    #[test]
    fn test_exit_at_middle_builders() {
        assert!(BollingerMeanReversion::default().exit_at_middle);
        assert!(
            !BollingerMeanReversion::default()
                .exit_at_middle(false)
                .exit_at_middle
        );
        assert!(!DonchianBreakout::new(10).exit_at_middle(false).exit_at_middle);
    }

    #[test]
    fn test_setup_caches_series_and_clone_resets_slots() {
        let mut map: HashMap<String, Vec<Option<f64>>> = HashMap::new();
        map.insert("sma_2".to_string(), vec![None, Some(1.0)]);
        map.insert("sma_3".to_string(), vec![None, None, Some(2.0)]);

        let mut s = SmaCrossover::new(2, 3);
        // SAFETY (every `get` below): `map` is alive and never mutated after
        // `setup()` runs.
        assert!(unsafe { s.fast_slot.get() }.is_none());
        s.setup(&map);
        assert_eq!(unsafe { s.fast_slot.get() }, map.get("sma_2"));
        assert_eq!(unsafe { s.slow_slot.get() }, map.get("sma_3"));

        // A clone must not inherit pointers into this run's data.
        let cloned = s.clone();
        assert!(unsafe { cloned.fast_slot.get() }.is_none());
        assert!(unsafe { cloned.slow_slot.get() }.is_none());
    }

    #[test]
    fn test_setup_ignores_missing_keys() {
        let map: HashMap<String, Vec<Option<f64>>> = HashMap::new();
        let mut s = RsiReversal::new(14);
        s.setup(&map);
        // SAFETY: the slot is expected to be unset, so nothing is dereferenced.
        assert!(unsafe { s.rsi_slot.get() }.is_none());
    }
}