finance_query/backtesting/strategy/
builder.rs1use std::collections::HashSet;
26
27use crate::backtesting::condition::Condition;
28use crate::backtesting::signal::Signal;
29use crate::indicators::Indicator;
30
31use super::{Strategy, StrategyContext};
32
33struct BoxedCondition {
35 evaluate_fn: Box<dyn Fn(&StrategyContext) -> bool + Send + Sync>,
36 required_indicators: Vec<(String, Indicator)>,
37 description: String,
38}
39
40impl BoxedCondition {
41 fn new<C: Condition>(cond: C) -> Self {
42 let required_indicators = cond.required_indicators();
43 let description = cond.description();
44 Self {
45 evaluate_fn: Box::new(move |ctx| cond.evaluate(ctx)),
46 required_indicators,
47 description,
48 }
49 }
50
51 fn evaluate(&self, ctx: &StrategyContext) -> bool {
52 (self.evaluate_fn)(ctx)
53 }
54
55 fn required_indicators(&self) -> &[(String, Indicator)] {
56 &self.required_indicators
57 }
58
59 fn description(&self) -> &str {
60 &self.description
61 }
62}
63
64pub struct StrategyBuilder<E = (), X = ()> {
69 name: String,
70 entry_condition: E,
71 exit_condition: X,
72 short_entry_condition: Option<BoxedCondition>,
73 short_exit_condition: Option<BoxedCondition>,
74}
75
76impl StrategyBuilder<(), ()> {
77 pub fn new(name: impl Into<String>) -> Self {
85 Self {
86 name: name.into(),
87 entry_condition: (),
88 exit_condition: (),
89 short_entry_condition: None,
90 short_exit_condition: None,
91 }
92 }
93}
94
95impl<X> StrategyBuilder<(), X> {
96 pub fn entry<C: Condition>(self, condition: C) -> StrategyBuilder<C, X> {
105 StrategyBuilder {
106 name: self.name,
107 entry_condition: condition,
108 exit_condition: self.exit_condition,
109 short_entry_condition: self.short_entry_condition,
110 short_exit_condition: self.short_exit_condition,
111 }
112 }
113}
114
115impl<E> StrategyBuilder<E, ()> {
116 pub fn exit<C: Condition>(self, condition: C) -> StrategyBuilder<E, C> {
126 StrategyBuilder {
127 name: self.name,
128 entry_condition: self.entry_condition,
129 exit_condition: condition,
130 short_entry_condition: self.short_entry_condition,
131 short_exit_condition: self.short_exit_condition,
132 }
133 }
134}
135
136impl<E: Condition, X: Condition> StrategyBuilder<E, X> {
137 pub fn with_short<SE: Condition, SX: Condition>(mut self, entry: SE, exit: SX) -> Self {
152 self.short_entry_condition = Some(BoxedCondition::new(entry));
153 self.short_exit_condition = Some(BoxedCondition::new(exit));
154 self
155 }
156
157 pub fn build(self) -> CustomStrategy<E, X> {
168 CustomStrategy {
169 name: self.name,
170 entry_condition: self.entry_condition,
171 exit_condition: self.exit_condition,
172 short_entry_condition: self.short_entry_condition,
173 short_exit_condition: self.short_exit_condition,
174 }
175 }
176}
177
178pub struct CustomStrategy<E: Condition, X: Condition> {
183 name: String,
184 entry_condition: E,
185 exit_condition: X,
186 short_entry_condition: Option<BoxedCondition>,
187 short_exit_condition: Option<BoxedCondition>,
188}
189
190impl<E: Condition, X: Condition> Strategy for CustomStrategy<E, X> {
191 fn name(&self) -> &str {
192 &self.name
193 }
194
195 fn required_indicators(&self) -> Vec<(String, Indicator)> {
196 let mut indicators = self.entry_condition.required_indicators();
197 indicators.extend(self.exit_condition.required_indicators());
198
199 if let Some(ref se) = self.short_entry_condition {
200 indicators.extend(se.required_indicators().iter().cloned());
201 }
202 if let Some(ref sx) = self.short_exit_condition {
203 indicators.extend(sx.required_indicators().iter().cloned());
204 }
205
206 let mut seen = HashSet::new();
208 indicators.retain(|(key, _)| seen.insert(key.clone()));
209
210 indicators
211 }
212
213 fn warmup_period(&self) -> usize {
214 let max_period = self
217 .required_indicators()
218 .iter()
219 .filter_map(|(key, _)| {
220 key.rsplit('_').next().and_then(|s| s.parse::<usize>().ok())
222 })
223 .max()
224 .unwrap_or(1);
225
226 max_period + 1
227 }
228
229 fn on_candle(&self, ctx: &StrategyContext) -> Signal {
230 let candle = ctx.current_candle();
231
232 if ctx.is_long() && self.exit_condition.evaluate(ctx) {
234 return Signal::exit(candle.timestamp, candle.close)
235 .with_reason(self.exit_condition.description());
236 }
237
238 if ctx.is_short()
239 && let Some(ref exit) = self.short_exit_condition
240 && exit.evaluate(ctx)
241 {
242 return Signal::exit(candle.timestamp, candle.close)
243 .with_reason(exit.description().to_string());
244 }
245
246 if !ctx.has_position() {
248 if self.entry_condition.evaluate(ctx) {
250 return Signal::long(candle.timestamp, candle.close)
251 .with_reason(self.entry_condition.description());
252 }
253
254 if let Some(ref entry) = self.short_entry_condition
256 && entry.evaluate(ctx)
257 {
258 return Signal::short(candle.timestamp, candle.close)
259 .with_reason(entry.description().to_string());
260 }
261 }
262
263 Signal::hold()
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::condition::{always_false, always_true};

    #[test]
    fn test_strategy_builder() {
        let strategy = StrategyBuilder::new("Test Strategy")
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_eq!(strategy.name(), "Test Strategy");
    }

    #[test]
    fn test_strategy_builder_with_short() {
        let strategy = StrategyBuilder::new("Test Strategy")
            .entry(always_true())
            .exit(always_false())
            .with_short(always_false(), always_true())
            .build();

        assert_eq!(strategy.name(), "Test Strategy");
        assert!(strategy.short_entry_condition.is_some());
        assert!(strategy.short_exit_condition.is_some());
    }

    #[test]
    fn test_required_indicators_deduplication() {
        use crate::backtesting::condition::Above;
        use crate::backtesting::refs::rsi;

        // Entry and exit both depend on the same RSI(14) series.
        let entry = Above::new(rsi(14), 70.0);
        let exit = Above::new(rsi(14), 30.0);

        let strategy = StrategyBuilder::new("Test").entry(entry).exit(exit).build();

        let indicators = strategy.required_indicators();
        assert_eq!(indicators.len(), 1);
        assert_eq!(indicators[0].0, "rsi_14");
    }

    #[test]
    fn test_warmup_period_single_indicator() {
        use crate::backtesting::condition::Above;
        use crate::backtesting::refs::rsi;

        let strategy = StrategyBuilder::new("Test")
            .entry(Above::new(rsi(14), 70.0))
            .exit(Above::new(rsi(14), 30.0))
            .build();

        // The only required indicator is `rsi_14`, so the warmup is 14 + 1.
        assert_eq!(strategy.warmup_period(), 15);
    }
}