finance_query/backtesting/strategy/
builder.rs1use std::collections::HashSet;
26
27use crate::backtesting::condition::{Condition, HtfIndicatorSpec};
28use crate::backtesting::signal::Signal;
29use crate::indicators::Indicator;
30
31use super::{Strategy, StrategyContext};
32
33struct BoxedCondition {
35 evaluate_fn: Box<dyn Fn(&StrategyContext) -> bool + Send + Sync>,
36 required_indicators: Vec<(String, Indicator)>,
37 htf_requirements: Vec<HtfIndicatorSpec>,
38 description: String,
39}
40
41impl BoxedCondition {
42 fn new<C: Condition>(cond: C) -> Self {
43 let required_indicators = cond.required_indicators();
44 let htf_requirements = cond.htf_requirements();
45 let description = cond.description();
46 Self {
47 evaluate_fn: Box::new(move |ctx| cond.evaluate(ctx)),
48 required_indicators,
49 htf_requirements,
50 description,
51 }
52 }
53
54 fn evaluate(&self, ctx: &StrategyContext) -> bool {
55 (self.evaluate_fn)(ctx)
56 }
57
58 fn required_indicators(&self) -> &[(String, Indicator)] {
59 &self.required_indicators
60 }
61
62 fn htf_requirements(&self) -> &[HtfIndicatorSpec] {
63 &self.htf_requirements
64 }
65
66 fn description(&self) -> &str {
67 &self.description
68 }
69}
70
71pub struct StrategyBuilder<E = (), X = ()> {
81 name: String,
82 entry_condition: E,
83 exit_condition: X,
84 short_entry_condition: Option<BoxedCondition>,
85 short_exit_condition: Option<BoxedCondition>,
86 regime_filter: Option<BoxedCondition>,
87 warmup_override: Option<usize>,
88}
89
90impl StrategyBuilder<(), ()> {
91 pub fn new(name: impl Into<String>) -> Self {
99 Self {
100 name: name.into(),
101 entry_condition: (),
102 exit_condition: (),
103 short_entry_condition: None,
104 short_exit_condition: None,
105 regime_filter: None,
106 warmup_override: None,
107 }
108 }
109}
110
111impl<X> StrategyBuilder<(), X> {
112 pub fn entry<C: Condition>(self, condition: C) -> StrategyBuilder<C, X> {
121 StrategyBuilder {
122 name: self.name,
123 entry_condition: condition,
124 exit_condition: self.exit_condition,
125 short_entry_condition: self.short_entry_condition,
126 short_exit_condition: self.short_exit_condition,
127 regime_filter: self.regime_filter,
128 warmup_override: self.warmup_override,
129 }
130 }
131}
132
133impl<E> StrategyBuilder<E, ()> {
134 pub fn exit<C: Condition>(self, condition: C) -> StrategyBuilder<E, C> {
144 StrategyBuilder {
145 name: self.name,
146 entry_condition: self.entry_condition,
147 exit_condition: condition,
148 short_entry_condition: self.short_entry_condition,
149 short_exit_condition: self.short_exit_condition,
150 regime_filter: self.regime_filter,
151 warmup_override: self.warmup_override,
152 }
153 }
154}
155
156impl<E, X> StrategyBuilder<E, X> {
157 pub fn regime_filter<C: Condition>(mut self, condition: C) -> Self {
181 self.regime_filter = Some(BoxedCondition::new(condition));
182 self
183 }
184}
185
186impl<E: Condition, X: Condition> StrategyBuilder<E, X> {
187 pub fn with_short<SE: Condition, SX: Condition>(mut self, entry: SE, exit: SX) -> Self {
202 self.short_entry_condition = Some(BoxedCondition::new(entry));
203 self.short_exit_condition = Some(BoxedCondition::new(exit));
204 self
205 }
206
207 pub fn warmup(mut self, bars: usize) -> Self {
223 self.warmup_override = Some(bars);
224 self
225 }
226
227 pub fn build(self) -> CustomStrategy<E, X> {
238 CustomStrategy {
239 name: self.name,
240 entry_condition: self.entry_condition,
241 exit_condition: self.exit_condition,
242 short_entry_condition: self.short_entry_condition,
243 short_exit_condition: self.short_exit_condition,
244 regime_filter: self.regime_filter,
245 warmup_override: self.warmup_override,
246 }
247 }
248}
249
250pub struct CustomStrategy<E: Condition, X: Condition> {
255 name: String,
256 entry_condition: E,
257 exit_condition: X,
258 short_entry_condition: Option<BoxedCondition>,
259 short_exit_condition: Option<BoxedCondition>,
260 regime_filter: Option<BoxedCondition>,
265 warmup_override: Option<usize>,
269}
270
271impl<E: Condition, X: Condition> Strategy for CustomStrategy<E, X> {
272 fn name(&self) -> &str {
273 &self.name
274 }
275
276 fn required_indicators(&self) -> Vec<(String, Indicator)> {
277 let mut indicators = self.entry_condition.required_indicators();
278 indicators.extend(self.exit_condition.required_indicators());
279
280 if let Some(ref se) = self.short_entry_condition {
281 indicators.extend(se.required_indicators().iter().cloned());
282 }
283 if let Some(ref sx) = self.short_exit_condition {
284 indicators.extend(sx.required_indicators().iter().cloned());
285 }
286 if let Some(ref rf) = self.regime_filter {
287 indicators.extend(rf.required_indicators().iter().cloned());
288 }
289
290 let mut seen = HashSet::new();
292 indicators.retain(|(key, _)| seen.insert(key.clone()));
293
294 indicators
295 }
296
297 fn htf_requirements(&self) -> Vec<HtfIndicatorSpec> {
298 let mut reqs = self.entry_condition.htf_requirements();
299 reqs.extend(self.exit_condition.htf_requirements());
300
301 if let Some(ref se) = self.short_entry_condition {
302 reqs.extend(se.htf_requirements().iter().cloned());
303 }
304 if let Some(ref sx) = self.short_exit_condition {
305 reqs.extend(sx.htf_requirements().iter().cloned());
306 }
307 if let Some(ref rf) = self.regime_filter {
308 reqs.extend(rf.htf_requirements().iter().cloned());
309 }
310
311 let mut seen = HashSet::new();
313 reqs.retain(|spec| seen.insert(spec.htf_key.clone()));
314 reqs
315 }
316
317 fn warmup_period(&self) -> usize {
318 if let Some(n) = self.warmup_override {
320 return n;
321 }
322
323 let max_warmup = self
327 .required_indicators()
328 .iter()
329 .map(|(_, indicator)| indicator.warmup_bars())
330 .max()
331 .unwrap_or(1);
332
333 max_warmup + 1
334 }
335
336 fn on_candle(&self, ctx: &StrategyContext) -> Signal {
337 let candle = ctx.current_candle();
338
339 if ctx.is_long() && self.exit_condition.evaluate(ctx) {
341 return Signal::exit(candle.timestamp, candle.close)
342 .with_reason(self.exit_condition.description());
343 }
344
345 if ctx.is_short()
346 && let Some(ref exit) = self.short_exit_condition
347 && exit.evaluate(ctx)
348 {
349 return Signal::exit(candle.timestamp, candle.close)
350 .with_reason(exit.description().to_string());
351 }
352
353 if !ctx.has_position() {
355 let regime_ok = self
357 .regime_filter
358 .as_ref()
359 .is_none_or(|rf| rf.evaluate(ctx));
360
361 if regime_ok {
362 if self.entry_condition.evaluate(ctx) {
364 return Signal::long(candle.timestamp, candle.close)
365 .with_reason(self.entry_condition.description());
366 }
367
368 if let Some(ref entry) = self.short_entry_condition
370 && entry.evaluate(ctx)
371 {
372 return Signal::short(candle.timestamp, candle.close)
373 .with_reason(entry.description().to_string());
374 }
375 }
376 }
377
378 Signal::hold()
379 }
380}
381
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::backtesting::condition::{always_false, always_true};
    use crate::backtesting::signal::SignalDirection;
    use crate::models::chart::Candle;

    /// Builds a candle whose open/high/low/close all equal `close`.
    fn make_candle(ts: i64, close: f64) -> Candle {
        Candle {
            timestamp: ts,
            open: close,
            high: close,
            low: close,
            close,
            volume: 1000,
            adj_close: None,
        }
    }

    /// Builds a context at bar 0 with no open position and 10,000 equity.
    fn make_ctx<'a>(
        candles: &'a [Candle],
        indicators: &'a HashMap<String, Vec<Option<f64>>>,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index: 0,
            position: None,
            equity: 10_000.0,
            indicators,
        }
    }

    /// Runs `strategy` on a single 100.0 candle while flat and with no
    /// indicator data, then checks the resulting signal direction.
    fn assert_flat_direction(strategy: &impl Strategy, expected: SignalDirection) {
        let candles = [make_candle(1, 100.0)];
        let indicators = HashMap::new();
        let ctx = make_ctx(&candles, &indicators);
        assert_eq!(strategy.on_candle(&ctx).direction, expected);
    }

    #[test]
    fn test_strategy_builder() {
        let built = StrategyBuilder::new("Test Strategy")
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_eq!(built.name(), "Test Strategy");
    }

    #[test]
    fn test_strategy_builder_with_short() {
        let built = StrategyBuilder::new("Test Strategy")
            .entry(always_true())
            .exit(always_false())
            .with_short(always_false(), always_true())
            .build();

        assert_eq!(built.name(), "Test Strategy");
        assert!(built.short_entry_condition.is_some());
        assert!(built.short_exit_condition.is_some());
    }

    #[test]
    fn test_required_indicators_deduplication() {
        use crate::backtesting::condition::Above;
        use crate::backtesting::refs::rsi;

        // Entry and exit both reference rsi(14); it must be requested only once.
        let built = StrategyBuilder::new("Test")
            .entry(Above::new(rsi(14), 70.0))
            .exit(Above::new(rsi(14), 30.0))
            .build();

        let required = built.required_indicators();
        let keys: Vec<&str> = required.iter().map(|(key, _)| key.as_str()).collect();
        assert_eq!(keys, ["rsi_14"]);
    }

    #[test]
    fn test_regime_filter_suppresses_entry_when_false() {
        let built = StrategyBuilder::new("Regime Test")
            .regime_filter(always_false())
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_flat_direction(&built, SignalDirection::Hold);
    }

    #[test]
    fn test_regime_filter_allows_entry_when_true() {
        let built = StrategyBuilder::new("Regime Test")
            .regime_filter(always_true())
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_flat_direction(&built, SignalDirection::Long);
    }

    #[test]
    fn test_no_regime_filter_behaves_normally() {
        let built = StrategyBuilder::new("No Regime")
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_flat_direction(&built, SignalDirection::Long);
    }

    #[test]
    fn test_regime_filter_does_not_block_exit() {
        use crate::backtesting::position::{Position, PositionSide};

        let built = StrategyBuilder::new("Regime Exit Test")
            .regime_filter(always_false())
            .entry(always_false())
            .exit(always_true())
            .build();

        let candles = [make_candle(1, 100.0)];
        let indicators = HashMap::new();
        let open_long = Position::new(
            PositionSide::Long,
            1,
            90.0,
            10.0,
            0.0,
            Signal::long(1, 90.0),
        );
        // Same context as the flat helper, but holding the long position.
        let ctx = StrategyContext {
            position: Some(&open_long),
            ..make_ctx(&candles, &indicators)
        };

        assert_eq!(built.on_candle(&ctx).direction, SignalDirection::Exit);
    }

    #[test]
    fn test_regime_filter_indicators_included_in_required() {
        use crate::backtesting::refs::{IndicatorRefExt, sma};
        use crate::indicators::Indicator;

        let built = StrategyBuilder::new("Regime Indicators")
            .regime_filter(sma(200).above_ref(sma(400)))
            .entry(always_true())
            .exit(always_false())
            .build();

        let required = built.required_indicators();
        let has_key = |wanted: &str| required.iter().any(|(key, _)| key == wanted);

        assert!(has_key("sma_200"), "sma_200 must be in required_indicators");
        assert!(has_key("sma_400"), "sma_400 must be in required_indicators");

        let (_, sma_200) = required
            .iter()
            .find(|(key, _)| key == "sma_200")
            .unwrap();
        assert!(matches!(sma_200, Indicator::Sma(200)));
    }

    #[test]
    fn test_regime_filter_callable_before_entry() {
        let built = StrategyBuilder::new("Order Test")
            .regime_filter(always_true())
            .entry(always_true())
            .exit(always_false())
            .build();

        assert!(built.regime_filter.is_some());
    }

    #[test]
    fn test_regime_filter_warmup_accounts_for_filter_indicators() {
        use crate::backtesting::refs::{IndicatorRefExt, sma};

        let built = StrategyBuilder::new("Warmup Test")
            .regime_filter(sma(400).above_ref(sma(200)))
            .entry(always_true())
            .exit(always_false())
            .build();

        let warmup = built.warmup_period();
        assert!(
            warmup >= 401,
            "warmup_period must account for sma(400): got {}",
            warmup
        );
    }
}