finance_query/backtesting/strategy/
mod.rs1mod builder;
23pub mod prebuilt;
24
25use std::collections::HashMap;
26
27use crate::indicators::Indicator;
28use crate::models::chart::Candle;
29
30use super::position::{Position, PositionSide};
31use super::signal::Signal;
32
33pub use builder::{CustomStrategy, StrategyBuilder};
35
36pub use prebuilt::{
38 BollingerMeanReversion, DonchianBreakout, MacdSignal, RsiReversal, SmaCrossover,
39 SuperTrendFollow,
40};
41
42#[non_exhaustive]
46pub struct StrategyContext<'a> {
47 pub candles: &'a [Candle],
49
50 pub index: usize,
52
53 pub position: Option<&'a Position>,
55
56 pub equity: f64,
58
59 pub indicators: &'a HashMap<String, Vec<Option<f64>>>,
61}
62
63impl<'a> StrategyContext<'a> {
64 pub fn current_candle(&self) -> &Candle {
66 &self.candles[self.index]
67 }
68
69 pub fn previous_candle(&self) -> Option<&Candle> {
71 if self.index > 0 {
72 Some(&self.candles[self.index - 1])
73 } else {
74 None
75 }
76 }
77
78 pub fn candle_at(&self, index: usize) -> Option<&Candle> {
80 self.candles.get(index)
81 }
82
83 pub fn indicator(&self, name: &str) -> Option<f64> {
85 self.indicators
86 .get(name)
87 .and_then(|v| v.get(self.index))
88 .and_then(|&v| v)
89 }
90
91 pub fn indicator_at(&self, name: &str, index: usize) -> Option<f64> {
93 self.indicators
94 .get(name)
95 .and_then(|v| v.get(index))
96 .and_then(|&v| v)
97 }
98
99 pub fn indicator_prev(&self, name: &str) -> Option<f64> {
101 if self.index > 0 {
102 self.indicator_at(name, self.index - 1)
103 } else {
104 None
105 }
106 }
107
108 pub fn has_position(&self) -> bool {
110 self.position.is_some()
111 }
112
113 pub fn is_long(&self) -> bool {
115 self.position
116 .map(|p| matches!(p.side, PositionSide::Long))
117 .unwrap_or(false)
118 }
119
120 pub fn is_short(&self) -> bool {
122 self.position
123 .map(|p| matches!(p.side, PositionSide::Short))
124 .unwrap_or(false)
125 }
126
127 pub fn close(&self) -> f64 {
129 self.current_candle().close
130 }
131
132 pub fn open(&self) -> f64 {
134 self.current_candle().open
135 }
136
137 pub fn high(&self) -> f64 {
139 self.current_candle().high
140 }
141
142 pub fn low(&self) -> f64 {
144 self.current_candle().low
145 }
146
147 pub fn volume(&self) -> i64 {
149 self.current_candle().volume
150 }
151
152 pub fn timestamp(&self) -> i64 {
154 self.current_candle().timestamp
155 }
156
157 pub fn crossed_above(&self, fast_name: &str, slow_name: &str) -> bool {
159 let fast_now = self.indicator(fast_name);
160 let slow_now = self.indicator(slow_name);
161 let fast_prev = self.indicator_prev(fast_name);
162 let slow_prev = self.indicator_prev(slow_name);
163
164 match (fast_now, slow_now, fast_prev, slow_prev) {
165 (Some(f), Some(s), Some(fp), Some(sp)) => fp < sp && f > s, _ => false,
167 }
168 }
169
170 pub fn crossed_below(&self, fast_name: &str, slow_name: &str) -> bool {
172 let fast_now = self.indicator(fast_name);
173 let slow_now = self.indicator(slow_name);
174 let fast_prev = self.indicator_prev(fast_name);
175 let slow_prev = self.indicator_prev(slow_name);
176
177 match (fast_now, slow_now, fast_prev, slow_prev) {
178 (Some(f), Some(s), Some(fp), Some(sp)) => fp > sp && f < s, _ => false,
180 }
181 }
182
183 pub fn indicator_crossed_above(&self, name: &str, threshold: f64) -> bool {
185 let now = self.indicator(name);
186 let prev = self.indicator_prev(name);
187
188 match (now, prev) {
189 (Some(n), Some(p)) => p <= threshold && n > threshold,
190 _ => false,
191 }
192 }
193
194 pub fn indicator_crossed_below(&self, name: &str, threshold: f64) -> bool {
196 let now = self.indicator(name);
197 let prev = self.indicator_prev(name);
198
199 match (now, prev) {
200 (Some(n), Some(p)) => p >= threshold && n < threshold,
201 _ => false,
202 }
203 }
204}
205
206pub trait Strategy: Send + Sync {
246 fn name(&self) -> &str;
248
249 fn required_indicators(&self) -> Vec<(String, Indicator)>;
254
255 fn on_candle(&self, ctx: &StrategyContext) -> Signal;
260
261 fn warmup_period(&self) -> usize {
264 1
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal strategy used to exercise the trait's required and default methods.
    struct TestStrategy;

    impl Strategy for TestStrategy {
        fn name(&self) -> &str {
            "Test Strategy"
        }

        fn required_indicators(&self) -> Vec<(String, Indicator)> {
            vec![("sma_10".to_string(), Indicator::Sma(10))]
        }

        fn on_candle(&self, ctx: &StrategyContext) -> Signal {
            if ctx.index == 5 {
                Signal::long(ctx.timestamp(), ctx.close())
            } else {
                Signal::hold()
            }
        }
    }

    /// Builds a flat candle (open == high == low == close) at `timestamp`.
    fn candle(timestamp: i64, close: f64) -> Candle {
        Candle {
            timestamp,
            open: close,
            high: close,
            low: close,
            close,
            volume: 1000,
            adj_close: None,
        }
    }

    /// Builds an indicator map from `(name, series)` pairs.
    fn indicators(pairs: Vec<(&str, Vec<Option<f64>>)>) -> HashMap<String, Vec<Option<f64>>> {
        pairs
            .into_iter()
            .map(|(name, values)| (name.to_string(), values))
            .collect()
    }

    /// Builds a context at `index` with no open position.
    fn context<'a>(
        candles: &'a [Candle],
        indicators: &'a HashMap<String, Vec<Option<f64>>>,
        index: usize,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index,
            position: None,
            equity: 10000.0,
            indicators,
        }
    }

    #[test]
    fn test_strategy_trait() {
        let strategy = TestStrategy;
        assert_eq!(strategy.name(), "Test Strategy");
        assert_eq!(strategy.required_indicators().len(), 1);
        assert_eq!(strategy.warmup_period(), 1);
    }

    #[test]
    fn test_context_crossover_detection() {
        let candles = vec![candle(1, 100.0), candle(2, 101.0)];
        let ind = indicators(vec![
            ("fast", vec![Some(9.0), Some(11.0)]),
            ("slow", vec![Some(10.0), Some(10.0)]),
        ]);
        let ctx = context(&candles, &ind, 1);

        assert!(ctx.crossed_above("fast", "slow"));
        assert!(!ctx.crossed_below("fast", "slow"));
    }

    #[test]
    fn test_context_crossunder_detection() {
        let candles = vec![candle(1, 100.0), candle(2, 99.0)];
        let ind = indicators(vec![
            ("fast", vec![Some(11.0), Some(9.0)]),
            ("slow", vec![Some(10.0), Some(10.0)]),
        ]);
        let ctx = context(&candles, &ind, 1);

        assert!(ctx.crossed_below("fast", "slow"));
        assert!(!ctx.crossed_above("fast", "slow"));
    }

    #[test]
    fn test_context_first_bar_has_no_history() {
        let candles = vec![candle(1, 100.0), candle(2, 101.0)];
        let ind = indicators(vec![
            ("fast", vec![Some(9.0), Some(11.0)]),
            ("slow", vec![Some(10.0), Some(10.0)]),
        ]);
        let ctx = context(&candles, &ind, 0);

        assert!(ctx.previous_candle().is_none());
        assert_eq!(ctx.indicator_prev("fast"), None);
        assert!(!ctx.crossed_above("fast", "slow"));
        assert!(!ctx.crossed_below("fast", "slow"));
    }

    #[test]
    fn test_context_missing_values() {
        let candles = vec![candle(1, 100.0), candle(2, 101.0)];
        let ind = indicators(vec![
            ("fast", vec![None, Some(11.0)]),
            ("slow", vec![Some(10.0), Some(10.0)]),
        ]);
        let ctx = context(&candles, &ind, 1);

        assert_eq!(ctx.indicator("fast"), Some(11.0));
        assert_eq!(ctx.indicator_prev("fast"), None);
        assert_eq!(ctx.indicator("unknown"), None);
        assert_eq!(ctx.indicator_at("fast", 99), None);
        // A missing previous value must never be treated as a crossover.
        assert!(!ctx.crossed_above("fast", "slow"));
        assert!(!ctx.crossed_above("fast", "unknown"));
    }

    #[test]
    fn test_context_threshold_crossings() {
        let candles = vec![candle(1, 100.0), candle(2, 101.0), candle(3, 102.0)];
        let ind = indicators(vec![("rsi", vec![Some(29.0), Some(30.0), Some(31.0)])]);

        // Reaching the threshold exactly is not a cross; moving past it from there is.
        let at_threshold = context(&candles, &ind, 1);
        assert!(!at_threshold.indicator_crossed_above("rsi", 30.0));
        assert!(!at_threshold.indicator_crossed_below("rsi", 30.0));

        let past_threshold = context(&candles, &ind, 2);
        assert!(past_threshold.indicator_crossed_above("rsi", 30.0));
        assert!(!past_threshold.indicator_crossed_below("rsi", 30.0));
    }

    #[test]
    fn test_context_candle_accessors() {
        let candles = vec![candle(1, 100.0), candle(2, 101.0)];
        let ind = HashMap::new();
        let ctx = context(&candles, &ind, 1);

        assert_eq!(ctx.timestamp(), 2);
        assert_eq!(ctx.close(), 101.0);
        assert_eq!(ctx.volume(), 1000);
        assert_eq!(ctx.previous_candle().map(|c| c.timestamp), Some(1));
        assert!(ctx.candle_at(2).is_none());
        assert!(!ctx.has_position());
        assert!(!ctx.is_long());
        assert!(!ctx.is_short());
    }
}