Skip to main content

finance_query/backtesting/strategy/
builder.rs

1//! Fluent strategy builder for creating custom strategies from conditions.
2//!
3//! This module provides a builder pattern for creating custom trading strategies
4//! using entry and exit conditions.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use finance_query::backtesting::strategy::StrategyBuilder;
10//! use finance_query::backtesting::refs::*;
11//! use finance_query::backtesting::condition::*;
12//!
13//! let strategy = StrategyBuilder::new("RSI Mean Reversion")
14//!     .entry(
15//!         rsi(14).crosses_below(30.0)
16//!             .and(price().above_ref(sma(200)))
17//!     )
18//!     .exit(
19//!         rsi(14).crosses_above(70.0)
20//!             .or(stop_loss(0.05))
21//!     )
22//!     .build();
23//! ```
24
25use std::collections::HashSet;
26
27use crate::backtesting::condition::{Condition, HtfIndicatorSpec};
28use crate::backtesting::signal::Signal;
29use crate::indicators::Indicator;
30
31use super::{Strategy, StrategyContext};
32
33/// Type-erased condition wrapper for storing heterogeneous conditions.
34struct BoxedCondition {
35    evaluate_fn: Box<dyn Fn(&StrategyContext) -> bool + Send + Sync>,
36    required_indicators: Vec<(String, Indicator)>,
37    htf_requirements: Vec<HtfIndicatorSpec>,
38    description: String,
39}
40
41impl BoxedCondition {
42    fn new<C: Condition>(cond: C) -> Self {
43        let required_indicators = cond.required_indicators();
44        let htf_requirements = cond.htf_requirements();
45        let description = cond.description();
46        Self {
47            evaluate_fn: Box::new(move |ctx| cond.evaluate(ctx)),
48            required_indicators,
49            htf_requirements,
50            description,
51        }
52    }
53
54    fn evaluate(&self, ctx: &StrategyContext) -> bool {
55        (self.evaluate_fn)(ctx)
56    }
57
58    fn required_indicators(&self) -> &[(String, Indicator)] {
59        &self.required_indicators
60    }
61
62    fn htf_requirements(&self) -> &[HtfIndicatorSpec] {
63        &self.htf_requirements
64    }
65
66    fn description(&self) -> &str {
67        &self.description
68    }
69}
70
71/// Builder for creating custom strategies with entry/exit conditions.
72///
73/// The builder enforces that both entry and exit conditions are provided
74/// before a strategy can be built.
75///
76/// An optional regime filter can be set at any point in the chain via
77/// [`.regime_filter()`](StrategyBuilder::regime_filter). When set, the filter
78/// is evaluated on every bar; if it returns `false`, all entry signals are
79/// suppressed. Exit signals are **never** blocked by the regime filter.
80pub struct StrategyBuilder<E = (), X = ()> {
81    name: String,
82    entry_condition: E,
83    exit_condition: X,
84    short_entry_condition: Option<BoxedCondition>,
85    short_exit_condition: Option<BoxedCondition>,
86    regime_filter: Option<BoxedCondition>,
87    warmup_override: Option<usize>,
88}
89
90impl StrategyBuilder<(), ()> {
91    /// Create a new strategy builder with a name.
92    ///
93    /// # Example
94    ///
95    /// ```ignore
96    /// let builder = StrategyBuilder::new("My Strategy");
97    /// ```
98    pub fn new(name: impl Into<String>) -> Self {
99        Self {
100            name: name.into(),
101            entry_condition: (),
102            exit_condition: (),
103            short_entry_condition: None,
104            short_exit_condition: None,
105            regime_filter: None,
106            warmup_override: None,
107        }
108    }
109}
110
111impl<X> StrategyBuilder<(), X> {
112    /// Set the entry condition for long positions.
113    ///
114    /// # Example
115    ///
116    /// ```ignore
117    /// let builder = StrategyBuilder::new("RSI Strategy")
118    ///     .entry(rsi(14).crosses_below(30.0));
119    /// ```
120    pub fn entry<C: Condition>(self, condition: C) -> StrategyBuilder<C, X> {
121        StrategyBuilder {
122            name: self.name,
123            entry_condition: condition,
124            exit_condition: self.exit_condition,
125            short_entry_condition: self.short_entry_condition,
126            short_exit_condition: self.short_exit_condition,
127            regime_filter: self.regime_filter,
128            warmup_override: self.warmup_override,
129        }
130    }
131}
132
133impl<E> StrategyBuilder<E, ()> {
134    /// Set the exit condition for long positions.
135    ///
136    /// # Example
137    ///
138    /// ```ignore
139    /// let builder = StrategyBuilder::new("RSI Strategy")
140    ///     .entry(rsi(14).crosses_below(30.0))
141    ///     .exit(rsi(14).crosses_above(70.0));
142    /// ```
143    pub fn exit<C: Condition>(self, condition: C) -> StrategyBuilder<E, C> {
144        StrategyBuilder {
145            name: self.name,
146            entry_condition: self.entry_condition,
147            exit_condition: condition,
148            short_entry_condition: self.short_entry_condition,
149            short_exit_condition: self.short_exit_condition,
150            regime_filter: self.regime_filter,
151            warmup_override: self.warmup_override,
152        }
153    }
154}
155
156impl<E, X> StrategyBuilder<E, X> {
157    /// Set a market regime filter.
158    ///
159    /// When set, entry signals (long and short) are suppressed on any bar
160    /// where the filter evaluates to `false`. Exit signals are **never**
161    /// blocked by the regime filter, ensuring open positions can always be
162    /// closed regardless of market conditions.
163    ///
164    /// The regime filter's indicators are included in `required_indicators()`
165    /// and therefore pre-computed by the engine like any other indicator.
166    ///
167    /// # Example
168    ///
169    /// ```rust,no_run
170    /// use finance_query::backtesting::strategy::StrategyBuilder;
171    /// use finance_query::backtesting::refs::*;
172    ///
173    /// // Only trade when price is above the 200-period SMA
174    /// let strategy = StrategyBuilder::new("Trend Following")
175    ///     .regime_filter(sma(200).above_ref(sma(400)))
176    ///     .entry(ema(10).crosses_above_ref(ema(30)))
177    ///     .exit(ema(10).crosses_below_ref(ema(30)))
178    ///     .build();
179    /// ```
180    pub fn regime_filter<C: Condition>(mut self, condition: C) -> Self {
181        self.regime_filter = Some(BoxedCondition::new(condition));
182        self
183    }
184}
185
186impl<E: Condition, X: Condition> StrategyBuilder<E, X> {
187    /// Enable short positions with entry and exit conditions.
188    ///
189    /// # Example
190    ///
191    /// ```ignore
192    /// let strategy = StrategyBuilder::new("RSI Strategy")
193    ///     .entry(rsi(14).crosses_below(30.0))
194    ///     .exit(rsi(14).crosses_above(70.0))
195    ///     .with_short(
196    ///         rsi(14).crosses_above(70.0),  // Short entry
197    ///         rsi(14).crosses_below(30.0),  // Short exit
198    ///     )
199    ///     .build();
200    /// ```
201    pub fn with_short<SE: Condition, SX: Condition>(mut self, entry: SE, exit: SX) -> Self {
202        self.short_entry_condition = Some(BoxedCondition::new(entry));
203        self.short_exit_condition = Some(BoxedCondition::new(exit));
204        self
205    }
206
207    /// Override the automatic warmup period with an explicit bar count.
208    ///
209    /// By default the warmup period is inferred from each indicator's
210    /// [`Indicator::warmup_bars()`] method. Use this override when the
211    /// automatic value doesn't match your specific needs.
212    ///
213    /// # Example
214    ///
215    /// ```ignore
216    /// let strategy = StrategyBuilder::new("MACD + RSI")
217    ///     .entry(macd(12, 26, 9).crosses_above_zero())
218    ///     .exit(rsi(14).crosses_above(70.0))
219    ///     .warmup(36) // explicit override
220    ///     .build();
221    /// ```
222    pub fn warmup(mut self, bars: usize) -> Self {
223        self.warmup_override = Some(bars);
224        self
225    }
226
227    /// Build the strategy.
228    ///
229    /// # Example
230    ///
231    /// ```ignore
232    /// let strategy = StrategyBuilder::new("My Strategy")
233    ///     .entry(rsi(14).crosses_below(30.0))
234    ///     .exit(rsi(14).crosses_above(70.0))
235    ///     .build();
236    /// ```
237    pub fn build(self) -> CustomStrategy<E, X> {
238        CustomStrategy {
239            name: self.name,
240            entry_condition: self.entry_condition,
241            exit_condition: self.exit_condition,
242            short_entry_condition: self.short_entry_condition,
243            short_exit_condition: self.short_exit_condition,
244            regime_filter: self.regime_filter,
245            warmup_override: self.warmup_override,
246        }
247    }
248}
249
250/// A custom strategy built from conditions.
251///
252/// This strategy evaluates entry and exit conditions on each candle
253/// and generates appropriate signals.
254pub struct CustomStrategy<E: Condition, X: Condition> {
255    name: String,
256    entry_condition: E,
257    exit_condition: X,
258    short_entry_condition: Option<BoxedCondition>,
259    short_exit_condition: Option<BoxedCondition>,
260    /// Optional market regime filter.
261    ///
262    /// When `Some`, entry signals are suppressed on bars where the filter
263    /// evaluates to `false`. Exit signals are unaffected.
264    regime_filter: Option<BoxedCondition>,
265    /// Explicit warmup period set via [`StrategyBuilder::warmup`].
266    ///
267    /// Overrides the heuristic in [`warmup_period`] when set.
268    warmup_override: Option<usize>,
269}
270
271impl<E: Condition, X: Condition> Strategy for CustomStrategy<E, X> {
272    fn name(&self) -> &str {
273        &self.name
274    }
275
276    fn required_indicators(&self) -> Vec<(String, Indicator)> {
277        let mut indicators = self.entry_condition.required_indicators();
278        indicators.extend(self.exit_condition.required_indicators());
279
280        if let Some(ref se) = self.short_entry_condition {
281            indicators.extend(se.required_indicators().iter().cloned());
282        }
283        if let Some(ref sx) = self.short_exit_condition {
284            indicators.extend(sx.required_indicators().iter().cloned());
285        }
286        if let Some(ref rf) = self.regime_filter {
287            indicators.extend(rf.required_indicators().iter().cloned());
288        }
289
290        // Deduplicate by key
291        let mut seen = HashSet::new();
292        indicators.retain(|(key, _)| seen.insert(key.clone()));
293
294        indicators
295    }
296
297    fn htf_requirements(&self) -> Vec<HtfIndicatorSpec> {
298        let mut reqs = self.entry_condition.htf_requirements();
299        reqs.extend(self.exit_condition.htf_requirements());
300
301        if let Some(ref se) = self.short_entry_condition {
302            reqs.extend(se.htf_requirements().iter().cloned());
303        }
304        if let Some(ref sx) = self.short_exit_condition {
305            reqs.extend(sx.htf_requirements().iter().cloned());
306        }
307        if let Some(ref rf) = self.regime_filter {
308            reqs.extend(rf.htf_requirements().iter().cloned());
309        }
310
311        // Deduplicate by htf_key — same stretched array cannot be stored twice
312        let mut seen = HashSet::new();
313        reqs.retain(|spec| seen.insert(spec.htf_key.clone()));
314        reqs
315    }
316
317    fn warmup_period(&self) -> usize {
318        // Explicit override wins — use it directly.
319        if let Some(n) = self.warmup_override {
320            return n;
321        }
322
323        // Use each indicator's own warmup calculation instead of parsing
324        // key suffixes (which fails for compound indicators like MACD and
325        // Bollinger).  `.warmup(n)` on the builder still overrides this.
326        let max_warmup = self
327            .required_indicators()
328            .iter()
329            .map(|(_, indicator)| indicator.warmup_bars())
330            .max()
331            .unwrap_or(1);
332
333        max_warmup + 1
334    }
335
336    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
337        let candle = ctx.current_candle();
338
339        // Check exit conditions first (for existing positions)
340        if ctx.is_long() && self.exit_condition.evaluate(ctx) {
341            return Signal::exit(candle.timestamp, candle.close)
342                .with_reason(self.exit_condition.description());
343        }
344
345        if ctx.is_short()
346            && let Some(ref exit) = self.short_exit_condition
347            && exit.evaluate(ctx)
348        {
349            return Signal::exit(candle.timestamp, candle.close)
350                .with_reason(exit.description().to_string());
351        }
352
353        // Check entry conditions (when no position)
354        if !ctx.has_position() {
355            // Regime filter gates all entries; exits are never suppressed.
356            let regime_ok = self
357                .regime_filter
358                .as_ref()
359                .is_none_or(|rf| rf.evaluate(ctx));
360
361            if regime_ok {
362                // Long entry
363                if self.entry_condition.evaluate(ctx) {
364                    return Signal::long(candle.timestamp, candle.close)
365                        .with_reason(self.entry_condition.description());
366                }
367
368                // Short entry
369                if let Some(ref entry) = self.short_entry_condition
370                    && entry.evaluate(ctx)
371                {
372                    return Signal::short(candle.timestamp, candle.close)
373                        .with_reason(entry.description().to_string());
374                }
375            }
376        }
377
378        Signal::hold()
379    }
380}
381
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::backtesting::condition::{always_false, always_true};
    use crate::backtesting::signal::SignalDirection;
    use crate::models::chart::Candle;

    /// Build a flat candle whose open, high and low all equal `close`.
    fn make_candle(ts: i64, close: f64) -> Candle {
        Candle {
            timestamp: ts,
            open: close,
            high: close,
            low: close,
            close,
            volume: 1000,
            adj_close: None,
            provider_id: None,
        }
    }

    /// Context positioned on the first candle with no open position.
    fn make_ctx<'a>(
        candles: &'a [Candle],
        indicators: &'a HashMap<String, Vec<Option<f64>>>,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index: 0,
            position: None,
            equity: 10_000.0,
            indicators,
        }
    }

    /// Run `strategy` on a single flat bar with no position and no indicators.
    fn signal_when_flat(strategy: &impl Strategy) -> Signal {
        let bars = vec![make_candle(1, 100.0)];
        let no_indicators = HashMap::new();
        let ctx = make_ctx(&bars, &no_indicators);
        strategy.on_candle(&ctx)
    }

    #[test]
    fn test_strategy_builder() {
        let built = StrategyBuilder::new("Test Strategy")
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_eq!(built.name(), "Test Strategy");
    }

    #[test]
    fn test_strategy_builder_with_short() {
        let built = StrategyBuilder::new("Test Strategy")
            .entry(always_true())
            .exit(always_false())
            .with_short(always_false(), always_true())
            .build();

        assert_eq!(built.name(), "Test Strategy");
        assert!(built.short_entry_condition.is_some());
        assert!(built.short_exit_condition.is_some());
    }

    #[test]
    fn test_required_indicators_deduplication() {
        use crate::backtesting::condition::Above;
        use crate::backtesting::refs::rsi;

        // Entry and exit both reference the same rsi(14) indicator.
        let built = StrategyBuilder::new("Test")
            .entry(Above::new(rsi(14), 70.0))
            .exit(Above::new(rsi(14), 30.0))
            .build();

        let required = built.required_indicators();
        // The shared indicator must be reported exactly once.
        assert_eq!(required.len(), 1);
        assert_eq!(required[0].0, "rsi_14");
    }

    // ── Regime filter tests ────────────────────────────────────────────

    #[test]
    fn test_regime_filter_suppresses_entry_when_false() {
        let built = StrategyBuilder::new("Regime Test")
            .regime_filter(always_false()) // regime is never active
            .entry(always_true())
            .exit(always_false())
            .build();

        // The always-true entry must be blocked by the inactive regime.
        assert_eq!(signal_when_flat(&built).direction, SignalDirection::Hold);
    }

    #[test]
    fn test_regime_filter_allows_entry_when_true() {
        let built = StrategyBuilder::new("Regime Test")
            .regime_filter(always_true()) // regime always active
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_eq!(signal_when_flat(&built).direction, SignalDirection::Long);
    }

    #[test]
    fn test_no_regime_filter_behaves_normally() {
        let built = StrategyBuilder::new("No Regime")
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_eq!(signal_when_flat(&built).direction, SignalDirection::Long);
    }

    #[test]
    fn test_regime_filter_does_not_block_exit() {
        use crate::backtesting::position::{Position, PositionSide};

        let built = StrategyBuilder::new("Regime Exit Test")
            .regime_filter(always_false()) // regime is off
            .entry(always_false())
            .exit(always_true()) // exit condition always fires
            .build();

        let bars = vec![make_candle(1, 100.0)];
        let no_indicators = HashMap::new();

        // An open long position, created through the public constructor.
        let open_long = Position::new(
            PositionSide::Long,
            1,
            90.0,
            10.0,
            0.0,
            Signal::long(1, 90.0),
        );

        let ctx = StrategyContext {
            candles: &bars,
            index: 0,
            position: Some(&open_long),
            equity: 10_000.0,
            indicators: &no_indicators,
        };

        // The exit must still fire although the regime filter is false.
        let signal = built.on_candle(&ctx);
        assert_eq!(signal.direction, SignalDirection::Exit);
    }

    #[test]
    fn test_regime_filter_indicators_included_in_required() {
        use crate::backtesting::refs::{IndicatorRefExt, sma};
        use crate::indicators::Indicator;

        let built = StrategyBuilder::new("Regime Indicators")
            .regime_filter(sma(200).above_ref(sma(400)))
            .entry(always_true())
            .exit(always_false())
            .build();

        let required = built.required_indicators();
        let has_key = |wanted: &str| required.iter().any(|(key, _)| key == wanted);

        assert!(has_key("sma_200"), "sma_200 must be in required_indicators");
        assert!(has_key("sma_400"), "sma_400 must be in required_indicators");

        // The key must map to the matching Indicator variant.
        let (_, sma_200) = required
            .iter()
            .find(|(key, _)| key == "sma_200")
            .unwrap();
        assert!(matches!(sma_200, Indicator::Sma(200)));
    }

    #[test]
    fn test_regime_filter_callable_before_entry() {
        // The chain must compile with regime_filter ahead of entry/exit.
        let built = StrategyBuilder::new("Order Test")
            .regime_filter(always_true())
            .entry(always_true())
            .exit(always_false())
            .build();

        assert!(built.regime_filter.is_some());
    }

    #[test]
    fn test_regime_filter_warmup_accounts_for_filter_indicators() {
        use crate::backtesting::refs::{IndicatorRefExt, sma};

        let built = StrategyBuilder::new("Warmup Test")
            .regime_filter(sma(400).above_ref(sma(200)))
            .entry(always_true())
            .exit(always_false())
            .build();

        // Warmup must be at least sma(400).warmup_bars() + 1 = 401
        let warmup = built.warmup_period();
        assert!(
            warmup >= 401,
            "warmup_period must account for sma(400): got {}",
            warmup
        );
    }
}
592            "warmup_period must account for sma(400): got {}",
593            strategy.warmup_period()
594        );
595    }
596}