Skip to main content

finance_query/backtesting/strategy/
mod.rs

//! Strategy trait and context for building trading strategies.
//!
//! This module provides the core `Strategy` trait and `StrategyContext` for
//! implementing custom trading strategies, as well as pre-built strategies
//! and a fluent builder API.
//!
//! # Building Custom Strategies
//!
//! Use the `StrategyBuilder` to compose a strategy from entry and exit conditions:
//!
//! ```ignore
//! use finance_query::backtesting::strategy::StrategyBuilder;
//! use finance_query::backtesting::refs::*;
//! use finance_query::backtesting::condition::*;
//!
//! let strategy = StrategyBuilder::new("My Strategy")
//!     .entry(rsi(14).crosses_below(30.0))
//!     .exit(rsi(14).crosses_above(70.0).or(stop_loss(0.05)))
//!     .build();
//! ```

22mod builder;
23pub mod prebuilt;
24
25use std::collections::HashMap;
26
27use crate::indicators::Indicator;
28use crate::models::chart::Candle;
29
30use super::position::{Position, PositionSide};
31use super::signal::Signal;
32
33// Re-export builder
34pub use builder::{CustomStrategy, StrategyBuilder};
35
36// Re-export prebuilt strategies
37pub use prebuilt::{
38    BollingerMeanReversion, DonchianBreakout, MacdSignal, RsiReversal, SmaCrossover,
39    SuperTrendFollow,
40};
41
42/// Context passed to strategy on each candle.
43///
44/// Provides access to historical data, current position, and pre-computed indicators.
45#[non_exhaustive]
46pub struct StrategyContext<'a> {
47    /// All candles up to and including current
48    pub candles: &'a [Candle],
49
50    /// Current candle index (0-based)
51    pub index: usize,
52
53    /// Current position (if any)
54    pub position: Option<&'a Position>,
55
56    /// Current portfolio equity
57    pub equity: f64,
58
59    /// Pre-computed indicator values (keyed by indicator name)
60    pub indicators: &'a HashMap<String, Vec<Option<f64>>>,
61}
62
63impl<'a> StrategyContext<'a> {
64    /// Get current candle
65    pub fn current_candle(&self) -> &Candle {
66        &self.candles[self.index]
67    }
68
69    /// Get previous candle (None if at start)
70    pub fn previous_candle(&self) -> Option<&Candle> {
71        if self.index > 0 {
72            Some(&self.candles[self.index - 1])
73        } else {
74            None
75        }
76    }
77
78    /// Get candle at specific index (None if out of bounds)
79    pub fn candle_at(&self, index: usize) -> Option<&Candle> {
80        self.candles.get(index)
81    }
82
83    /// Get indicator value at current index
84    pub fn indicator(&self, name: &str) -> Option<f64> {
85        self.indicators
86            .get(name)
87            .and_then(|v| v.get(self.index))
88            .and_then(|&v| v)
89    }
90
91    /// Get indicator value at specific index
92    pub fn indicator_at(&self, name: &str, index: usize) -> Option<f64> {
93        self.indicators
94            .get(name)
95            .and_then(|v| v.get(index))
96            .and_then(|&v| v)
97    }
98
99    /// Get indicator value at previous index
100    pub fn indicator_prev(&self, name: &str) -> Option<f64> {
101        if self.index > 0 {
102            self.indicator_at(name, self.index - 1)
103        } else {
104            None
105        }
106    }
107
108    /// Check if we have a position
109    pub fn has_position(&self) -> bool {
110        self.position.is_some()
111    }
112
113    /// Check if we have a long position
114    pub fn is_long(&self) -> bool {
115        self.position
116            .map(|p| matches!(p.side, PositionSide::Long))
117            .unwrap_or(false)
118    }
119
120    /// Check if we have a short position
121    pub fn is_short(&self) -> bool {
122        self.position
123            .map(|p| matches!(p.side, PositionSide::Short))
124            .unwrap_or(false)
125    }
126
127    /// Get current close price
128    pub fn close(&self) -> f64 {
129        self.current_candle().close
130    }
131
132    /// Get current open price
133    pub fn open(&self) -> f64 {
134        self.current_candle().open
135    }
136
137    /// Get current high price
138    pub fn high(&self) -> f64 {
139        self.current_candle().high
140    }
141
142    /// Get current low price
143    pub fn low(&self) -> f64 {
144        self.current_candle().low
145    }
146
147    /// Get current volume
148    pub fn volume(&self) -> i64 {
149        self.current_candle().volume
150    }
151
152    /// Get current timestamp
153    pub fn timestamp(&self) -> i64 {
154        self.current_candle().timestamp
155    }
156
157    /// Check if crossover occurred (fast crosses above slow)
158    pub fn crossed_above(&self, fast_name: &str, slow_name: &str) -> bool {
159        let fast_now = self.indicator(fast_name);
160        let slow_now = self.indicator(slow_name);
161        let fast_prev = self.indicator_prev(fast_name);
162        let slow_prev = self.indicator_prev(slow_name);
163
164        match (fast_now, slow_now, fast_prev, slow_prev) {
165            (Some(f), Some(s), Some(fp), Some(sp)) => fp < sp && f > s, // Fixed: changed <= to <
166            _ => false,
167        }
168    }
169
170    /// Check if crossover occurred (fast crosses below slow)
171    pub fn crossed_below(&self, fast_name: &str, slow_name: &str) -> bool {
172        let fast_now = self.indicator(fast_name);
173        let slow_now = self.indicator(slow_name);
174        let fast_prev = self.indicator_prev(fast_name);
175        let slow_prev = self.indicator_prev(slow_name);
176
177        match (fast_now, slow_now, fast_prev, slow_prev) {
178            (Some(f), Some(s), Some(fp), Some(sp)) => fp > sp && f < s, // Fixed: changed >= to >
179            _ => false,
180        }
181    }
182
183    /// Check if indicator crossed above a threshold
184    pub fn indicator_crossed_above(&self, name: &str, threshold: f64) -> bool {
185        let now = self.indicator(name);
186        let prev = self.indicator_prev(name);
187
188        match (now, prev) {
189            (Some(n), Some(p)) => p <= threshold && n > threshold,
190            _ => false,
191        }
192    }
193
194    /// Check if indicator crossed below a threshold
195    pub fn indicator_crossed_below(&self, name: &str, threshold: f64) -> bool {
196        let now = self.indicator(name);
197        let prev = self.indicator_prev(name);
198
199        match (now, prev) {
200            (Some(n), Some(p)) => p >= threshold && n < threshold,
201            _ => false,
202        }
203    }
204}
205
206/// Core strategy trait - implement this for custom strategies.
207///
208/// # Example
209///
210/// ```ignore
211/// use finance_query::backtesting::{Strategy, StrategyContext, Signal};
212/// use finance_query::indicators::Indicator;
213///
214/// struct MyStrategy {
215///     sma_period: usize,
216/// }
217///
218/// impl Strategy for MyStrategy {
219///     fn name(&self) -> &str {
220///         "My Custom Strategy"
221///     }
222///
223///     fn required_indicators(&self) -> Vec<(String, Indicator)> {
224///         vec![
225///             (format!("sma_{}", self.sma_period), Indicator::Sma(self.sma_period)),
226///         ]
227///     }
228///
229///     fn on_candle(&self, ctx: &StrategyContext) -> Signal {
230///         let sma = ctx.indicator(&format!("sma_{}", self.sma_period));
231///         let close = ctx.close();
232///
233///         match sma {
234///             Some(sma_val) if close > sma_val && !ctx.has_position() => {
235///                 Signal::long(ctx.timestamp(), close)
236///             }
237///             Some(sma_val) if close < sma_val && ctx.is_long() => {
238///                 Signal::exit(ctx.timestamp(), close)
239///             }
240///             _ => Signal::hold(),
241///         }
242///     }
243/// }
244/// ```
245pub trait Strategy: Send + Sync {
246    /// Strategy name (for reporting)
247    fn name(&self) -> &str;
248
249    /// Required indicators this strategy needs.
250    ///
251    /// Returns list of (indicator_name, Indicator) pairs.
252    /// The engine will pre-compute these and make them available via `StrategyContext::indicator()`.
253    fn required_indicators(&self) -> Vec<(String, Indicator)>;
254
255    /// Called on each candle to generate a signal.
256    ///
257    /// Return `Signal::hold()` for no action, `Signal::long()` to enter long,
258    /// `Signal::short()` to enter short, or `Signal::exit()` to close position.
259    fn on_candle(&self, ctx: &StrategyContext) -> Signal;
260
261    /// Optional: minimum candles required before strategy can generate signals.
262    /// Default is 1 (strategy can run from first candle).
263    fn warmup_period(&self) -> usize {
264        1
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal strategy: goes long on the candle at index 5 and holds otherwise.
    struct TestStrategy;

    impl Strategy for TestStrategy {
        fn name(&self) -> &str {
            "Test Strategy"
        }

        fn required_indicators(&self) -> Vec<(String, Indicator)> {
            vec![("sma_10".to_string(), Indicator::Sma(10))]
        }

        fn on_candle(&self, ctx: &StrategyContext) -> Signal {
            if ctx.index == 5 {
                Signal::long(ctx.timestamp(), ctx.close())
            } else {
                Signal::hold()
            }
        }
    }

    /// Builds a candle whose open equals `close` and whose high/low sit 1.0 above/below it.
    fn candle(timestamp: i64, close: f64) -> Candle {
        Candle {
            timestamp,
            open: close,
            high: close + 1.0,
            low: close - 1.0,
            close,
            volume: 1000,
            adj_close: None,
        }
    }

    /// Builds a context with no open position over `candles`, evaluated at `index`.
    fn context<'a>(
        candles: &'a [Candle],
        indicators: &'a HashMap<String, Vec<Option<f64>>>,
        index: usize,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index,
            position: None,
            equity: 10_000.0,
            indicators,
        }
    }

    #[test]
    fn test_strategy_trait() {
        let strategy = TestStrategy;
        assert_eq!(strategy.name(), "Test Strategy");
        assert_eq!(strategy.required_indicators().len(), 1);
        assert_eq!(strategy.warmup_period(), 1);
    }

    #[test]
    fn test_context_crossover_detection() {
        let candles = [candle(1, 100.0), candle(2, 101.0)];
        let indicators = HashMap::from([
            ("fast".to_string(), vec![Some(9.0), Some(11.0)]),
            ("slow".to_string(), vec![Some(10.0), Some(10.0)]),
        ]);
        let ctx = context(&candles, &indicators, 1);

        // fast was 9 (below slow 10), now 11 (above slow 10) -> crossed above
        assert!(ctx.crossed_above("fast", "slow"));
        assert!(!ctx.crossed_below("fast", "slow"));
    }

    #[test]
    fn test_context_price_accessors() {
        let candles = [candle(1, 100.0), candle(2, 101.0)];
        let indicators = HashMap::new();
        let ctx = context(&candles, &indicators, 1);

        assert_eq!(ctx.timestamp(), 2);
        assert_eq!(ctx.close(), 101.0);
        assert_eq!(ctx.open(), 101.0);
        assert_eq!(ctx.high(), 102.0);
        assert_eq!(ctx.low(), 100.0);
        assert_eq!(ctx.volume(), 1000);
        assert_eq!(ctx.previous_candle().map(|c| c.timestamp), Some(1));
        assert!(ctx.candle_at(2).is_none());
        assert!(!ctx.has_position());
        assert!(!ctx.is_long());
        assert!(!ctx.is_short());
    }

    #[test]
    fn test_context_lookbacks_at_start() {
        let candles = [candle(1, 100.0)];
        let indicators = HashMap::from([("rsi".to_string(), vec![Some(25.0)])]);
        let ctx = context(&candles, &indicators, 0);

        assert!(ctx.previous_candle().is_none());
        assert_eq!(ctx.indicator("rsi"), Some(25.0));
        assert_eq!(ctx.indicator_prev("rsi"), None);
        assert_eq!(ctx.indicator("missing"), None);
        // No previous bar, so no crossing can be reported.
        assert!(!ctx.indicator_crossed_above("rsi", 20.0));
        assert!(!ctx.crossed_above("rsi", "rsi"));
    }

    #[test]
    fn test_context_threshold_crossings() {
        let candles = [candle(1, 100.0), candle(2, 101.0)];
        let indicators = HashMap::from([
            ("rising".to_string(), vec![Some(30.0), Some(31.0)]),
            ("falling".to_string(), vec![Some(70.0), Some(69.0)]),
            ("gap".to_string(), vec![None, Some(50.0)]),
        ]);
        let ctx = context(&candles, &indicators, 1);

        // The previous value may sit exactly on the threshold.
        assert!(ctx.indicator_crossed_above("rising", 30.0));
        assert!(!ctx.indicator_crossed_below("rising", 30.0));
        assert!(ctx.indicator_crossed_below("falling", 70.0));
        assert!(!ctx.indicator_crossed_above("falling", 70.0));
        // A missing previous value never counts as a crossing.
        assert!(!ctx.indicator_crossed_above("gap", 40.0));
    }
}
339}