finance_query/backtesting/strategy/mod.rs
1//! Strategy trait and context for building trading strategies.
2//!
3//! This module provides the core `Strategy` trait and `StrategyContext` for
4//! implementing custom trading strategies, as well as pre-built strategies
5//! and a fluent builder API.
6//!
7//! # Building Custom Strategies
8//!
9//! Use the `StrategyBuilder` for creating strategies with conditions:
10//!
11//! ```ignore
12//! use finance_query::backtesting::strategy::StrategyBuilder;
13//! use finance_query::backtesting::refs::*;
14//! use finance_query::backtesting::condition::*;
15//!
16//! let strategy = StrategyBuilder::new("My Strategy")
17//! .entry(rsi(14).crosses_below(30.0))
18//! .exit(rsi(14).crosses_above(70.0).or(stop_loss(0.05)))
19//! .build();
20//! ```
21
22mod builder;
23mod ensemble;
24pub mod prebuilt;
25
26use std::collections::HashMap;
27
28use crate::backtesting::condition::HtfIndicatorSpec;
29use crate::indicators::Indicator;
30use crate::models::chart::Candle;
31
32use super::position::{Position, PositionSide};
33use super::signal::Signal;
34
35// Re-export builder
36pub use builder::{CustomStrategy, StrategyBuilder};
37
38// Re-export ensemble
39pub use ensemble::{EnsembleMode, EnsembleStrategy};
40
41// Re-export prebuilt strategies
42pub use prebuilt::{
43 BollingerMeanReversion, DonchianBreakout, MacdSignal, RsiReversal, SmaCrossover,
44 SuperTrendFollow,
45};
46
47/// Context passed to strategy on each candle.
48///
49/// Provides access to historical data, current position, and pre-computed indicators.
50#[non_exhaustive]
51pub struct StrategyContext<'a> {
52 /// All candles up to and including current
53 pub candles: &'a [Candle],
54
55 /// Current candle index (0-based)
56 pub index: usize,
57
58 /// Current position (if any)
59 pub position: Option<&'a Position>,
60
61 /// Current portfolio equity
62 pub equity: f64,
63
64 /// Pre-computed indicator values (keyed by indicator name)
65 pub indicators: &'a HashMap<String, Vec<Option<f64>>>,
66}
67
68impl<'a> StrategyContext<'a> {
69 /// Get current candle
70 pub fn current_candle(&self) -> &Candle {
71 &self.candles[self.index]
72 }
73
74 /// Get previous candle (None if at start)
75 pub fn previous_candle(&self) -> Option<&Candle> {
76 if self.index > 0 {
77 Some(&self.candles[self.index - 1])
78 } else {
79 None
80 }
81 }
82
83 /// Get candle at specific index (None if out of bounds)
84 pub fn candle_at(&self, index: usize) -> Option<&Candle> {
85 self.candles.get(index)
86 }
87
88 /// Get indicator value at current index
89 pub fn indicator(&self, name: &str) -> Option<f64> {
90 self.indicators
91 .get(name)
92 .and_then(|v| v.get(self.index))
93 .and_then(|&v| v)
94 }
95
96 /// Get indicator value at specific index
97 pub fn indicator_at(&self, name: &str, index: usize) -> Option<f64> {
98 self.indicators
99 .get(name)
100 .and_then(|v| v.get(index))
101 .and_then(|&v| v)
102 }
103
104 /// Get indicator value at previous index
105 pub fn indicator_prev(&self, name: &str) -> Option<f64> {
106 if self.index > 0 {
107 self.indicator_at(name, self.index - 1)
108 } else {
109 None
110 }
111 }
112
113 /// Check if we have a position
114 pub fn has_position(&self) -> bool {
115 self.position.is_some()
116 }
117
118 /// Check if we have a long position
119 pub fn is_long(&self) -> bool {
120 self.position
121 .map(|p| matches!(p.side, PositionSide::Long))
122 .unwrap_or(false)
123 }
124
125 /// Check if we have a short position
126 pub fn is_short(&self) -> bool {
127 self.position
128 .map(|p| matches!(p.side, PositionSide::Short))
129 .unwrap_or(false)
130 }
131
132 /// Get current close price
133 pub fn close(&self) -> f64 {
134 self.current_candle().close
135 }
136
137 /// Get current open price
138 pub fn open(&self) -> f64 {
139 self.current_candle().open
140 }
141
142 /// Get current high price
143 pub fn high(&self) -> f64 {
144 self.current_candle().high
145 }
146
147 /// Get current low price
148 pub fn low(&self) -> f64 {
149 self.current_candle().low
150 }
151
152 /// Get current volume
153 pub fn volume(&self) -> i64 {
154 self.current_candle().volume
155 }
156
157 /// Get current timestamp
158 pub fn timestamp(&self) -> i64 {
159 self.current_candle().timestamp
160 }
161
162 /// Create a Long signal from the current candle's timestamp and close price.
163 pub fn signal_long(&self) -> Signal {
164 Signal::long(self.timestamp(), self.close())
165 }
166
167 /// Create a Short signal from the current candle's timestamp and close price.
168 pub fn signal_short(&self) -> Signal {
169 Signal::short(self.timestamp(), self.close())
170 }
171
172 /// Create an Exit signal from the current candle's timestamp and close price.
173 pub fn signal_exit(&self) -> Signal {
174 Signal::exit(self.timestamp(), self.close())
175 }
176
177 /// Check if crossover occurred (fast crosses above slow)
178 pub fn crossed_above(&self, fast_name: &str, slow_name: &str) -> bool {
179 let fast_now = self.indicator(fast_name);
180 let slow_now = self.indicator(slow_name);
181 let fast_prev = self.indicator_prev(fast_name);
182 let slow_prev = self.indicator_prev(slow_name);
183
184 match (fast_now, slow_now, fast_prev, slow_prev) {
185 (Some(f), Some(s), Some(fp), Some(sp)) => fp < sp && f > s, // Fixed: changed <= to <
186 _ => false,
187 }
188 }
189
190 /// Check if crossover occurred (fast crosses below slow)
191 pub fn crossed_below(&self, fast_name: &str, slow_name: &str) -> bool {
192 let fast_now = self.indicator(fast_name);
193 let slow_now = self.indicator(slow_name);
194 let fast_prev = self.indicator_prev(fast_name);
195 let slow_prev = self.indicator_prev(slow_name);
196
197 match (fast_now, slow_now, fast_prev, slow_prev) {
198 (Some(f), Some(s), Some(fp), Some(sp)) => fp > sp && f < s, // Fixed: changed >= to >
199 _ => false,
200 }
201 }
202
203 /// Check if indicator crossed above a threshold.
204 ///
205 /// Returns `true` when `prev <= threshold` **and** `current > threshold`.
206 /// The inclusive lower bound (`<=`) means a signal fires even when the
207 /// previous bar sat exactly on the threshold, which is the conventional
208 /// "crosses above" definition. This is intentionally asymmetric with the
209 /// strict crossover check in [`crossed_above`](Self::crossed_above) where
210 /// both sides use strict inequalities — threshold crossings and
211 /// indicator-vs-indicator crossings have different semantics.
212 pub fn indicator_crossed_above(&self, name: &str, threshold: f64) -> bool {
213 let now = self.indicator(name);
214 let prev = self.indicator_prev(name);
215
216 match (now, prev) {
217 (Some(n), Some(p)) => p <= threshold && n > threshold,
218 _ => false,
219 }
220 }
221
222 /// Check if indicator crossed below a threshold.
223 ///
224 /// Returns `true` when `prev >= threshold` **and** `current < threshold`.
225 /// See [`indicator_crossed_above`](Self::indicator_crossed_above) for the
226 /// rationale behind the inclusive/exclusive choice on each side.
227 pub fn indicator_crossed_below(&self, name: &str, threshold: f64) -> bool {
228 let now = self.indicator(name);
229 let prev = self.indicator_prev(name);
230
231 match (now, prev) {
232 (Some(n), Some(p)) => p >= threshold && n < threshold,
233 _ => false,
234 }
235 }
236}
237
238/// Core strategy trait - implement this for custom strategies.
239///
240/// # Example
241///
242/// ```ignore
243/// use finance_query::backtesting::{Strategy, StrategyContext, Signal};
244/// use finance_query::indicators::Indicator;
245///
246/// struct MyStrategy {
247/// sma_period: usize,
248/// }
249///
250/// impl Strategy for MyStrategy {
251/// fn name(&self) -> &str {
252/// "My Custom Strategy"
253/// }
254///
255/// fn required_indicators(&self) -> Vec<(String, Indicator)> {
256/// vec![
257/// (format!("sma_{}", self.sma_period), Indicator::Sma(self.sma_period)),
258/// ]
259/// }
260///
261/// fn on_candle(&self, ctx: &StrategyContext) -> Signal {
262/// let sma = ctx.indicator(&format!("sma_{}", self.sma_period));
263/// let close = ctx.close();
264///
265/// match sma {
266/// Some(sma_val) if close > sma_val && !ctx.has_position() => {
267/// Signal::long(ctx.timestamp(), close)
268/// }
269/// Some(sma_val) if close < sma_val && ctx.is_long() => {
270/// Signal::exit(ctx.timestamp(), close)
271/// }
272/// _ => Signal::hold(),
273/// }
274/// }
275/// }
276/// ```
277pub trait Strategy: Send + Sync {
278 /// Strategy name (for reporting)
279 fn name(&self) -> &str;
280
281 /// Required indicators this strategy needs.
282 ///
283 /// Returns list of (indicator_name, Indicator) pairs.
284 /// The engine will pre-compute these and make them available via `StrategyContext::indicator()`.
285 fn required_indicators(&self) -> Vec<(String, Indicator)>;
286
287 /// Higher-timeframe indicators required by this strategy.
288 ///
289 /// The engine resamples candles to each unique interval, computes the
290 /// listed indicators on the resampled data, and stores stretched
291 /// (base-timeframe-length) arrays in `StrategyContext::indicators` under
292 /// the `htf_key` names. Strategies built with [`StrategyBuilder`] implement
293 /// this automatically; raw [`Strategy`] implementations that use HTF
294 /// conditions should override this to avoid the O(n²) dynamic fallback.
295 ///
296 /// [`StrategyBuilder`]: crate::backtesting::strategy::StrategyBuilder
297 fn htf_requirements(&self) -> Vec<HtfIndicatorSpec> {
298 vec![]
299 }
300
301 /// Called on each candle to generate a signal.
302 ///
303 /// Return `Signal::hold()` for no action, `Signal::long()` to enter long,
304 /// `Signal::short()` to enter short, or `Signal::exit()` to close position.
305 fn on_candle(&self, ctx: &StrategyContext) -> Signal;
306
307 /// Optional: minimum candles required before strategy can generate signals.
308 /// Default is 1 (strategy can run from first candle).
309 fn warmup_period(&self) -> usize {
310 1
311 }
312}
313
314impl Strategy for Box<dyn Strategy> {
315 fn name(&self) -> &str {
316 (**self).name()
317 }
318 fn required_indicators(&self) -> Vec<(String, Indicator)> {
319 (**self).required_indicators()
320 }
321 fn htf_requirements(&self) -> Vec<HtfIndicatorSpec> {
322 (**self).htf_requirements()
323 }
324 fn on_candle(&self, ctx: &StrategyContext) -> Signal {
325 (**self).on_candle(ctx)
326 }
327 fn warmup_period(&self) -> usize {
328 (**self).warmup_period()
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal strategy used to exercise the trait's provided methods.
    struct TestStrategy;

    impl Strategy for TestStrategy {
        fn name(&self) -> &str {
            "Test Strategy"
        }

        fn required_indicators(&self) -> Vec<(String, Indicator)> {
            vec![("sma_10".to_string(), Indicator::Sma(10))]
        }

        fn on_candle(&self, ctx: &StrategyContext) -> Signal {
            if ctx.index == 5 {
                Signal::long(ctx.timestamp(), ctx.close())
            } else {
                Signal::hold()
            }
        }
    }

    /// Builds a candle with the given timestamp and close; other prices
    /// are derived from `close`.
    fn candle(timestamp: i64, close: f64) -> Candle {
        Candle {
            timestamp,
            open: close,
            high: close + 1.0,
            low: close - 1.0,
            close,
            volume: 1000,
            adj_close: None,
        }
    }

    /// Builds `n` candles with timestamps `1..=n`.
    fn candles(n: usize) -> Vec<Candle> {
        (1..=n as i64).map(|ts| candle(ts, 100.0 + ts as f64)).collect()
    }

    /// Builds an indicator map from `(name, series)` pairs.
    fn series(entries: Vec<(&str, Vec<Option<f64>>)>) -> HashMap<String, Vec<Option<f64>>> {
        entries
            .into_iter()
            .map(|(name, values)| (name.to_string(), values))
            .collect()
    }

    /// Builds a flat (no position) context positioned at `index`.
    fn ctx<'a>(
        candles: &'a [Candle],
        indicators: &'a HashMap<String, Vec<Option<f64>>>,
        index: usize,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index,
            position: None,
            equity: 10000.0,
            indicators,
        }
    }

    #[test]
    fn test_strategy_trait() {
        let strategy = TestStrategy;
        assert_eq!(strategy.name(), "Test Strategy");
        assert_eq!(strategy.required_indicators().len(), 1);
        assert_eq!(strategy.warmup_period(), 1);
        assert!(strategy.htf_requirements().is_empty());
    }

    #[test]
    fn test_context_crossover_detection() {
        let bars = candles(2);
        let indicators = series(vec![
            ("fast", vec![Some(9.0), Some(11.0)]),
            ("slow", vec![Some(10.0), Some(10.0)]),
        ]);
        let ctx = ctx(&bars, &indicators, 1);

        // fast was 9 (below slow 10), now 11 (above slow 10) -> crossed above
        assert!(ctx.crossed_above("fast", "slow"));
        assert!(!ctx.crossed_below("fast", "slow"));
    }

    #[test]
    fn test_context_crossunder_detection() {
        let bars = candles(2);
        let indicators = series(vec![
            ("fast", vec![Some(11.0), Some(9.0)]),
            ("slow", vec![Some(10.0), Some(10.0)]),
        ]);
        let ctx = ctx(&bars, &indicators, 1);

        assert!(ctx.crossed_below("fast", "slow"));
        assert!(!ctx.crossed_above("fast", "slow"));
    }

    #[test]
    fn test_crossover_is_strict_on_both_sides() {
        // fast touches slow at bar 1, then moves above at bar 2: neither bar
        // satisfies `prev < slow && now > slow`.
        let bars = candles(3);
        let indicators = series(vec![
            ("fast", vec![Some(9.0), Some(10.0), Some(11.0)]),
            ("slow", vec![Some(10.0), Some(10.0), Some(10.0)]),
        ]);
        assert!(!ctx(&bars, &indicators, 1).crossed_above("fast", "slow"));
        assert!(!ctx(&bars, &indicators, 2).crossed_above("fast", "slow"));
    }

    #[test]
    fn test_crossover_false_at_first_bar_or_missing_values() {
        let bars = candles(2);
        let indicators = series(vec![
            ("fast", vec![Some(9.0), Some(11.0)]),
            ("slow", vec![Some(10.0), None]),
        ]);
        assert!(!ctx(&bars, &indicators, 0).crossed_above("fast", "slow"));
        assert!(!ctx(&bars, &indicators, 1).crossed_above("fast", "slow"));
        assert!(!ctx(&bars, &indicators, 1).crossed_above("fast", "missing"));
        assert!(!ctx(&bars, &indicators, 1).crossed_below("missing", "slow"));
    }

    #[test]
    fn test_threshold_crossings_are_inclusive_on_previous_bar() {
        let bars = candles(3);
        let indicators = series(vec![
            ("up", vec![Some(29.0), Some(30.0), Some(31.0)]),
            ("down", vec![Some(71.0), Some(70.0), Some(69.0)]),
        ]);

        // Landing exactly on the threshold is not yet a cross...
        assert!(!ctx(&bars, &indicators, 1).indicator_crossed_above("up", 30.0));
        assert!(!ctx(&bars, &indicators, 1).indicator_crossed_below("down", 70.0));
        // ...but leaving it from the threshold is.
        assert!(ctx(&bars, &indicators, 2).indicator_crossed_above("up", 30.0));
        assert!(ctx(&bars, &indicators, 2).indicator_crossed_below("down", 70.0));
        // No previous bar at index 0.
        assert!(!ctx(&bars, &indicators, 0).indicator_crossed_above("up", 0.0));
    }

    #[test]
    fn test_indicator_lookups() {
        let bars = candles(3);
        let indicators = series(vec![("x", vec![Some(1.0), None, Some(3.0)])]);

        let at_end = ctx(&bars, &indicators, 2);
        assert_eq!(at_end.indicator("x"), Some(3.0));
        assert_eq!(at_end.indicator_prev("x"), None);
        assert_eq!(at_end.indicator_at("x", 0), Some(1.0));
        assert_eq!(at_end.indicator_at("x", 10), None);
        assert_eq!(at_end.indicator("missing"), None);

        let at_start = ctx(&bars, &indicators, 0);
        assert_eq!(at_start.indicator("x"), Some(1.0));
        assert_eq!(at_start.indicator_prev("x"), None);
    }

    #[test]
    fn test_candle_navigation() {
        let bars = candles(3);
        let indicators = series(vec![]);

        let first = ctx(&bars, &indicators, 0);
        assert!(first.previous_candle().is_none());
        assert_eq!(first.current_candle().timestamp, 1);
        assert_eq!(first.timestamp(), 1);

        let last = ctx(&bars, &indicators, 2);
        assert_eq!(last.previous_candle().map(|c| c.timestamp), Some(2));
        assert_eq!(last.candle_at(0).map(|c| c.timestamp), Some(1));
        assert!(last.candle_at(5).is_none());
        assert_eq!(last.close(), 103.0);
        assert_eq!(last.high(), 104.0);
        assert_eq!(last.low(), 102.0);
        assert_eq!(last.open(), 103.0);
        assert_eq!(last.volume(), 1000);
    }

    #[test]
    fn test_position_accessors_when_flat() {
        let bars = candles(1);
        let indicators = series(vec![]);
        let ctx = ctx(&bars, &indicators, 0);

        assert!(!ctx.has_position());
        assert!(!ctx.is_long());
        assert!(!ctx.is_short());
    }
}