Skip to main content

finance_query/backtesting/strategy/
mod.rs

//! Strategy trait and context for building trading strategies.
//!
//! This module provides the core `Strategy` trait and `StrategyContext` for
//! implementing custom trading strategies, as well as pre-built strategies
//! and a fluent builder API.
//!
//! # Building Custom Strategies
//!
//! Use the `StrategyBuilder` for creating strategies with conditions:
//!
//! ```ignore
//! use finance_query::backtesting::strategy::StrategyBuilder;
//! use finance_query::backtesting::refs::*;
//! use finance_query::backtesting::condition::*;
//!
//! let strategy = StrategyBuilder::new("My Strategy")
//!     .entry(rsi(14).crosses_below(30.0))
//!     .exit(rsi(14).crosses_above(70.0).or(stop_loss(0.05)))
//!     .build();
//! ```
21
22mod builder;
23mod ensemble;
24pub mod prebuilt;
25
26use std::collections::HashMap;
27
28use crate::backtesting::condition::HtfIndicatorSpec;
29use crate::indicators::Indicator;
30use crate::models::chart::Candle;
31
32use super::position::{Position, PositionSide};
33use super::signal::Signal;
34
35// Re-export builder
36pub use builder::{CustomStrategy, StrategyBuilder};
37
38// Re-export ensemble
39pub use ensemble::{EnsembleMode, EnsembleStrategy};
40
41// Re-export prebuilt strategies
42pub use prebuilt::{
43    BollingerMeanReversion, DonchianBreakout, MacdSignal, RsiReversal, SmaCrossover,
44    SuperTrendFollow,
45};
46
47/// Context passed to strategy on each candle.
48///
49/// Provides access to historical data, current position, and pre-computed indicators.
50#[non_exhaustive]
51pub struct StrategyContext<'a> {
52    /// All candles up to and including current
53    pub candles: &'a [Candle],
54
55    /// Current candle index (0-based)
56    pub index: usize,
57
58    /// Current position (if any)
59    pub position: Option<&'a Position>,
60
61    /// Current portfolio equity
62    pub equity: f64,
63
64    /// Pre-computed indicator values (keyed by indicator name)
65    pub indicators: &'a HashMap<String, Vec<Option<f64>>>,
66}
67
68impl<'a> StrategyContext<'a> {
69    /// Get current candle
70    pub fn current_candle(&self) -> &Candle {
71        &self.candles[self.index]
72    }
73
74    /// Get previous candle (None if at start)
75    pub fn previous_candle(&self) -> Option<&Candle> {
76        if self.index > 0 {
77            Some(&self.candles[self.index - 1])
78        } else {
79            None
80        }
81    }
82
83    /// Get candle at specific index (None if out of bounds)
84    pub fn candle_at(&self, index: usize) -> Option<&Candle> {
85        self.candles.get(index)
86    }
87
88    /// Get indicator value at current index
89    pub fn indicator(&self, name: &str) -> Option<f64> {
90        self.indicators
91            .get(name)
92            .and_then(|v| v.get(self.index))
93            .and_then(|&v| v)
94    }
95
96    /// Get indicator value at specific index
97    pub fn indicator_at(&self, name: &str, index: usize) -> Option<f64> {
98        self.indicators
99            .get(name)
100            .and_then(|v| v.get(index))
101            .and_then(|&v| v)
102    }
103
104    /// Get indicator value at previous index
105    pub fn indicator_prev(&self, name: &str) -> Option<f64> {
106        if self.index > 0 {
107            self.indicator_at(name, self.index - 1)
108        } else {
109            None
110        }
111    }
112
113    /// Check if we have a position
114    pub fn has_position(&self) -> bool {
115        self.position.is_some()
116    }
117
118    /// Check if we have a long position
119    pub fn is_long(&self) -> bool {
120        self.position
121            .map(|p| matches!(p.side, PositionSide::Long))
122            .unwrap_or(false)
123    }
124
125    /// Check if we have a short position
126    pub fn is_short(&self) -> bool {
127        self.position
128            .map(|p| matches!(p.side, PositionSide::Short))
129            .unwrap_or(false)
130    }
131
132    /// Get current close price
133    pub fn close(&self) -> f64 {
134        self.current_candle().close
135    }
136
137    /// Get current open price
138    pub fn open(&self) -> f64 {
139        self.current_candle().open
140    }
141
142    /// Get current high price
143    pub fn high(&self) -> f64 {
144        self.current_candle().high
145    }
146
147    /// Get current low price
148    pub fn low(&self) -> f64 {
149        self.current_candle().low
150    }
151
152    /// Get current volume
153    pub fn volume(&self) -> i64 {
154        self.current_candle().volume
155    }
156
157    /// Get current timestamp
158    pub fn timestamp(&self) -> i64 {
159        self.current_candle().timestamp
160    }
161
162    /// Create a Long signal from the current candle's timestamp and close price.
163    pub fn signal_long(&self) -> Signal {
164        Signal::long(self.timestamp(), self.close())
165    }
166
167    /// Create a Short signal from the current candle's timestamp and close price.
168    pub fn signal_short(&self) -> Signal {
169        Signal::short(self.timestamp(), self.close())
170    }
171
172    /// Create an Exit signal from the current candle's timestamp and close price.
173    pub fn signal_exit(&self) -> Signal {
174        Signal::exit(self.timestamp(), self.close())
175    }
176
177    /// Check if crossover occurred (fast crosses above slow)
178    pub fn crossed_above(&self, fast_name: &str, slow_name: &str) -> bool {
179        let fast_now = self.indicator(fast_name);
180        let slow_now = self.indicator(slow_name);
181        let fast_prev = self.indicator_prev(fast_name);
182        let slow_prev = self.indicator_prev(slow_name);
183
184        match (fast_now, slow_now, fast_prev, slow_prev) {
185            (Some(f), Some(s), Some(fp), Some(sp)) => fp < sp && f > s, // Fixed: changed <= to <
186            _ => false,
187        }
188    }
189
190    /// Check if crossover occurred (fast crosses below slow)
191    pub fn crossed_below(&self, fast_name: &str, slow_name: &str) -> bool {
192        let fast_now = self.indicator(fast_name);
193        let slow_now = self.indicator(slow_name);
194        let fast_prev = self.indicator_prev(fast_name);
195        let slow_prev = self.indicator_prev(slow_name);
196
197        match (fast_now, slow_now, fast_prev, slow_prev) {
198            (Some(f), Some(s), Some(fp), Some(sp)) => fp > sp && f < s, // Fixed: changed >= to >
199            _ => false,
200        }
201    }
202
203    /// Check if indicator crossed above a threshold.
204    ///
205    /// Returns `true` when `prev <= threshold` **and** `current > threshold`.
206    /// The inclusive lower bound (`<=`) means a signal fires even when the
207    /// previous bar sat exactly on the threshold, which is the conventional
208    /// "crosses above" definition.  This is intentionally asymmetric with the
209    /// strict crossover check in [`crossed_above`](Self::crossed_above) where
210    /// both sides use strict inequalities — threshold crossings and
211    /// indicator-vs-indicator crossings have different semantics.
212    pub fn indicator_crossed_above(&self, name: &str, threshold: f64) -> bool {
213        let now = self.indicator(name);
214        let prev = self.indicator_prev(name);
215
216        match (now, prev) {
217            (Some(n), Some(p)) => p <= threshold && n > threshold,
218            _ => false,
219        }
220    }
221
222    /// Check if indicator crossed below a threshold.
223    ///
224    /// Returns `true` when `prev >= threshold` **and** `current < threshold`.
225    /// See [`indicator_crossed_above`](Self::indicator_crossed_above) for the
226    /// rationale behind the inclusive/exclusive choice on each side.
227    pub fn indicator_crossed_below(&self, name: &str, threshold: f64) -> bool {
228        let now = self.indicator(name);
229        let prev = self.indicator_prev(name);
230
231        match (now, prev) {
232            (Some(n), Some(p)) => p >= threshold && n < threshold,
233            _ => false,
234        }
235    }
236}
237
238/// Core strategy trait - implement this for custom strategies.
239///
240/// # Example
241///
242/// ```ignore
243/// use finance_query::backtesting::{Strategy, StrategyContext, Signal};
244/// use finance_query::indicators::Indicator;
245///
246/// struct MyStrategy {
247///     sma_period: usize,
248/// }
249///
250/// impl Strategy for MyStrategy {
251///     fn name(&self) -> &str {
252///         "My Custom Strategy"
253///     }
254///
255///     fn required_indicators(&self) -> Vec<(String, Indicator)> {
256///         vec![
257///             (format!("sma_{}", self.sma_period), Indicator::Sma(self.sma_period)),
258///         ]
259///     }
260///
261///     fn on_candle(&self, ctx: &StrategyContext) -> Signal {
262///         let sma = ctx.indicator(&format!("sma_{}", self.sma_period));
263///         let close = ctx.close();
264///
265///         match sma {
266///             Some(sma_val) if close > sma_val && !ctx.has_position() => {
267///                 Signal::long(ctx.timestamp(), close)
268///             }
269///             Some(sma_val) if close < sma_val && ctx.is_long() => {
270///                 Signal::exit(ctx.timestamp(), close)
271///             }
272///             _ => Signal::hold(),
273///         }
274///     }
275/// }
276/// ```
277pub trait Strategy: Send + Sync {
278    /// Strategy name (for reporting)
279    fn name(&self) -> &str;
280
281    /// Required indicators this strategy needs.
282    ///
283    /// Returns list of (indicator_name, Indicator) pairs.
284    /// The engine will pre-compute these and make them available via `StrategyContext::indicator()`.
285    fn required_indicators(&self) -> Vec<(String, Indicator)>;
286
287    /// Higher-timeframe indicators required by this strategy.
288    ///
289    /// The engine resamples candles to each unique interval, computes the
290    /// listed indicators on the resampled data, and stores stretched
291    /// (base-timeframe-length) arrays in `StrategyContext::indicators` under
292    /// the `htf_key` names. Strategies built with [`StrategyBuilder`] implement
293    /// this automatically; raw [`Strategy`] implementations that use HTF
294    /// conditions should override this to avoid the O(n²) dynamic fallback.
295    ///
296    /// [`StrategyBuilder`]: crate::backtesting::strategy::StrategyBuilder
297    fn htf_requirements(&self) -> Vec<HtfIndicatorSpec> {
298        vec![]
299    }
300
301    /// Called on each candle to generate a signal.
302    ///
303    /// Return `Signal::hold()` for no action, `Signal::long()` to enter long,
304    /// `Signal::short()` to enter short, or `Signal::exit()` to close position.
305    fn on_candle(&self, ctx: &StrategyContext) -> Signal;
306
307    /// Optional: minimum candles required before strategy can generate signals.
308    /// Default is 1 (strategy can run from first candle).
309    fn warmup_period(&self) -> usize {
310        1
311    }
312}
313
314impl Strategy for Box<dyn Strategy> {
315    fn name(&self) -> &str {
316        (**self).name()
317    }
318    fn required_indicators(&self) -> Vec<(String, Indicator)> {
319        (**self).required_indicators()
320    }
321    fn htf_requirements(&self) -> Vec<HtfIndicatorSpec> {
322        (**self).htf_requirements()
323    }
324    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
325        (**self).on_candle(ctx)
326    }
327    fn warmup_period(&self) -> usize {
328        (**self).warmup_period()
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal strategy: goes long on the candle at index 5, holds otherwise.
    struct TestStrategy;

    impl Strategy for TestStrategy {
        fn name(&self) -> &str {
            "Test Strategy"
        }

        fn required_indicators(&self) -> Vec<(String, Indicator)> {
            vec![("sma_10".to_string(), Indicator::Sma(10))]
        }

        fn on_candle(&self, ctx: &StrategyContext) -> Signal {
            match ctx.index {
                5 => Signal::long(ctx.timestamp(), ctx.close()),
                _ => Signal::hold(),
            }
        }
    }

    /// Builds a test candle; open, low and volume are fixed, the rest vary.
    fn candle(timestamp: i64, high: f64, close: f64) -> Candle {
        Candle {
            timestamp,
            open: 100.0,
            high,
            low: 99.0,
            close,
            volume: 1000,
            adj_close: None,
        }
    }

    #[test]
    fn test_strategy_trait() {
        let s = TestStrategy;
        assert_eq!(s.name(), "Test Strategy");
        assert_eq!(s.required_indicators().len(), 1);
        assert_eq!(s.warmup_period(), 1);
    }

    #[test]
    fn test_context_crossover_detection() {
        let candles = vec![candle(1, 101.0, 100.0), candle(2, 102.0, 101.0)];

        let mut indicators: HashMap<String, Vec<Option<f64>>> = HashMap::new();
        indicators.insert("fast".into(), vec![Some(9.0), Some(11.0)]);
        indicators.insert("slow".into(), vec![Some(10.0), Some(10.0)]);

        let ctx = StrategyContext {
            candles: &candles,
            index: 1,
            position: None,
            equity: 10000.0,
            indicators: &indicators,
        };

        // fast moves from 9 (below slow = 10) to 11 (above slow = 10): an upward cross.
        assert!(ctx.crossed_above("fast", "slow"));
        assert!(!ctx.crossed_below("fast", "slow"));
    }
}