Skip to main content

finance_query/backtesting/strategy/
builder.rs

1//! Fluent strategy builder for creating custom strategies from conditions.
2//!
3//! This module provides a builder pattern for creating custom trading strategies
4//! using entry and exit conditions.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use finance_query::backtesting::strategy::StrategyBuilder;
10//! use finance_query::backtesting::refs::*;
11//! use finance_query::backtesting::condition::*;
12//!
13//! let strategy = StrategyBuilder::new("RSI Mean Reversion")
14//!     .entry(
15//!         rsi(14).crosses_below(30.0)
16//!             .and(price().above_ref(sma(200)))
17//!     )
18//!     .exit(
19//!         rsi(14).crosses_above(70.0)
20//!             .or(stop_loss(0.05))
21//!     )
22//!     .build();
23//! ```
24
25use std::collections::HashSet;
26
27use crate::backtesting::condition::{Condition, HtfIndicatorSpec};
28use crate::backtesting::signal::Signal;
29use crate::indicators::Indicator;
30
31use super::{Strategy, StrategyContext};
32
33/// Type-erased condition wrapper for storing heterogeneous conditions.
34struct BoxedCondition {
35    evaluate_fn: Box<dyn Fn(&StrategyContext) -> bool + Send + Sync>,
36    required_indicators: Vec<(String, Indicator)>,
37    htf_requirements: Vec<HtfIndicatorSpec>,
38    description: String,
39}
40
41impl BoxedCondition {
42    fn new<C: Condition>(cond: C) -> Self {
43        let required_indicators = cond.required_indicators();
44        let htf_requirements = cond.htf_requirements();
45        let description = cond.description();
46        Self {
47            evaluate_fn: Box::new(move |ctx| cond.evaluate(ctx)),
48            required_indicators,
49            htf_requirements,
50            description,
51        }
52    }
53
54    fn evaluate(&self, ctx: &StrategyContext) -> bool {
55        (self.evaluate_fn)(ctx)
56    }
57
58    fn required_indicators(&self) -> &[(String, Indicator)] {
59        &self.required_indicators
60    }
61
62    fn htf_requirements(&self) -> &[HtfIndicatorSpec] {
63        &self.htf_requirements
64    }
65
66    fn description(&self) -> &str {
67        &self.description
68    }
69}
70
71/// Builder for creating custom strategies with entry/exit conditions.
72///
73/// The builder enforces that both entry and exit conditions are provided
74/// before a strategy can be built.
75///
76/// An optional regime filter can be set at any point in the chain via
77/// [`.regime_filter()`](StrategyBuilder::regime_filter). When set, the filter
78/// is evaluated on every bar; if it returns `false`, all entry signals are
79/// suppressed. Exit signals are **never** blocked by the regime filter.
80pub struct StrategyBuilder<E = (), X = ()> {
81    name: String,
82    entry_condition: E,
83    exit_condition: X,
84    short_entry_condition: Option<BoxedCondition>,
85    short_exit_condition: Option<BoxedCondition>,
86    regime_filter: Option<BoxedCondition>,
87    warmup_override: Option<usize>,
88}
89
90impl StrategyBuilder<(), ()> {
91    /// Create a new strategy builder with a name.
92    ///
93    /// # Example
94    ///
95    /// ```ignore
96    /// let builder = StrategyBuilder::new("My Strategy");
97    /// ```
98    pub fn new(name: impl Into<String>) -> Self {
99        Self {
100            name: name.into(),
101            entry_condition: (),
102            exit_condition: (),
103            short_entry_condition: None,
104            short_exit_condition: None,
105            regime_filter: None,
106            warmup_override: None,
107        }
108    }
109}
110
111impl<X> StrategyBuilder<(), X> {
112    /// Set the entry condition for long positions.
113    ///
114    /// # Example
115    ///
116    /// ```ignore
117    /// let builder = StrategyBuilder::new("RSI Strategy")
118    ///     .entry(rsi(14).crosses_below(30.0));
119    /// ```
120    pub fn entry<C: Condition>(self, condition: C) -> StrategyBuilder<C, X> {
121        StrategyBuilder {
122            name: self.name,
123            entry_condition: condition,
124            exit_condition: self.exit_condition,
125            short_entry_condition: self.short_entry_condition,
126            short_exit_condition: self.short_exit_condition,
127            regime_filter: self.regime_filter,
128            warmup_override: self.warmup_override,
129        }
130    }
131}
132
133impl<E> StrategyBuilder<E, ()> {
134    /// Set the exit condition for long positions.
135    ///
136    /// # Example
137    ///
138    /// ```ignore
139    /// let builder = StrategyBuilder::new("RSI Strategy")
140    ///     .entry(rsi(14).crosses_below(30.0))
141    ///     .exit(rsi(14).crosses_above(70.0));
142    /// ```
143    pub fn exit<C: Condition>(self, condition: C) -> StrategyBuilder<E, C> {
144        StrategyBuilder {
145            name: self.name,
146            entry_condition: self.entry_condition,
147            exit_condition: condition,
148            short_entry_condition: self.short_entry_condition,
149            short_exit_condition: self.short_exit_condition,
150            regime_filter: self.regime_filter,
151            warmup_override: self.warmup_override,
152        }
153    }
154}
155
156impl<E, X> StrategyBuilder<E, X> {
157    /// Set a market regime filter.
158    ///
159    /// When set, entry signals (long and short) are suppressed on any bar
160    /// where the filter evaluates to `false`. Exit signals are **never**
161    /// blocked by the regime filter, ensuring open positions can always be
162    /// closed regardless of market conditions.
163    ///
164    /// The regime filter's indicators are included in `required_indicators()`
165    /// and therefore pre-computed by the engine like any other indicator.
166    ///
167    /// # Example
168    ///
169    /// ```rust,no_run
170    /// use finance_query::backtesting::strategy::StrategyBuilder;
171    /// use finance_query::backtesting::refs::*;
172    ///
173    /// // Only trade when price is above the 200-period SMA
174    /// let strategy = StrategyBuilder::new("Trend Following")
175    ///     .regime_filter(sma(200).above_ref(sma(400)))
176    ///     .entry(ema(10).crosses_above_ref(ema(30)))
177    ///     .exit(ema(10).crosses_below_ref(ema(30)))
178    ///     .build();
179    /// ```
180    pub fn regime_filter<C: Condition>(mut self, condition: C) -> Self {
181        self.regime_filter = Some(BoxedCondition::new(condition));
182        self
183    }
184}
185
186impl<E: Condition, X: Condition> StrategyBuilder<E, X> {
187    /// Enable short positions with entry and exit conditions.
188    ///
189    /// # Example
190    ///
191    /// ```ignore
192    /// let strategy = StrategyBuilder::new("RSI Strategy")
193    ///     .entry(rsi(14).crosses_below(30.0))
194    ///     .exit(rsi(14).crosses_above(70.0))
195    ///     .with_short(
196    ///         rsi(14).crosses_above(70.0),  // Short entry
197    ///         rsi(14).crosses_below(30.0),  // Short exit
198    ///     )
199    ///     .build();
200    /// ```
201    pub fn with_short<SE: Condition, SX: Condition>(mut self, entry: SE, exit: SX) -> Self {
202        self.short_entry_condition = Some(BoxedCondition::new(entry));
203        self.short_exit_condition = Some(BoxedCondition::new(exit));
204        self
205    }
206
207    /// Override the automatic warmup period with an explicit bar count.
208    ///
209    /// By default the warmup period is inferred from each indicator's
210    /// [`Indicator::warmup_bars()`] method. Use this override when the
211    /// automatic value doesn't match your specific needs.
212    ///
213    /// # Example
214    ///
215    /// ```ignore
216    /// let strategy = StrategyBuilder::new("MACD + RSI")
217    ///     .entry(macd(12, 26, 9).crosses_above_zero())
218    ///     .exit(rsi(14).crosses_above(70.0))
219    ///     .warmup(36) // explicit override
220    ///     .build();
221    /// ```
222    pub fn warmup(mut self, bars: usize) -> Self {
223        self.warmup_override = Some(bars);
224        self
225    }
226
227    /// Build the strategy.
228    ///
229    /// # Example
230    ///
231    /// ```ignore
232    /// let strategy = StrategyBuilder::new("My Strategy")
233    ///     .entry(rsi(14).crosses_below(30.0))
234    ///     .exit(rsi(14).crosses_above(70.0))
235    ///     .build();
236    /// ```
237    pub fn build(self) -> CustomStrategy<E, X> {
238        CustomStrategy {
239            name: self.name,
240            entry_condition: self.entry_condition,
241            exit_condition: self.exit_condition,
242            short_entry_condition: self.short_entry_condition,
243            short_exit_condition: self.short_exit_condition,
244            regime_filter: self.regime_filter,
245            warmup_override: self.warmup_override,
246        }
247    }
248}
249
250/// A custom strategy built from conditions.
251///
252/// This strategy evaluates entry and exit conditions on each candle
253/// and generates appropriate signals.
254pub struct CustomStrategy<E: Condition, X: Condition> {
255    name: String,
256    entry_condition: E,
257    exit_condition: X,
258    short_entry_condition: Option<BoxedCondition>,
259    short_exit_condition: Option<BoxedCondition>,
260    /// Optional market regime filter.
261    ///
262    /// When `Some`, entry signals are suppressed on bars where the filter
263    /// evaluates to `false`. Exit signals are unaffected.
264    regime_filter: Option<BoxedCondition>,
265    /// Explicit warmup period set via [`StrategyBuilder::warmup`].
266    ///
267    /// Overrides the heuristic in [`warmup_period`] when set.
268    warmup_override: Option<usize>,
269}
270
271impl<E: Condition, X: Condition> Strategy for CustomStrategy<E, X> {
272    fn name(&self) -> &str {
273        &self.name
274    }
275
276    fn required_indicators(&self) -> Vec<(String, Indicator)> {
277        let mut indicators = self.entry_condition.required_indicators();
278        indicators.extend(self.exit_condition.required_indicators());
279
280        if let Some(ref se) = self.short_entry_condition {
281            indicators.extend(se.required_indicators().iter().cloned());
282        }
283        if let Some(ref sx) = self.short_exit_condition {
284            indicators.extend(sx.required_indicators().iter().cloned());
285        }
286        if let Some(ref rf) = self.regime_filter {
287            indicators.extend(rf.required_indicators().iter().cloned());
288        }
289
290        // Deduplicate by key
291        let mut seen = HashSet::new();
292        indicators.retain(|(key, _)| seen.insert(key.clone()));
293
294        indicators
295    }
296
297    fn htf_requirements(&self) -> Vec<HtfIndicatorSpec> {
298        let mut reqs = self.entry_condition.htf_requirements();
299        reqs.extend(self.exit_condition.htf_requirements());
300
301        if let Some(ref se) = self.short_entry_condition {
302            reqs.extend(se.htf_requirements().iter().cloned());
303        }
304        if let Some(ref sx) = self.short_exit_condition {
305            reqs.extend(sx.htf_requirements().iter().cloned());
306        }
307        if let Some(ref rf) = self.regime_filter {
308            reqs.extend(rf.htf_requirements().iter().cloned());
309        }
310
311        // Deduplicate by htf_key — same stretched array cannot be stored twice
312        let mut seen = HashSet::new();
313        reqs.retain(|spec| seen.insert(spec.htf_key.clone()));
314        reqs
315    }
316
317    fn warmup_period(&self) -> usize {
318        // Explicit override wins — use it directly.
319        if let Some(n) = self.warmup_override {
320            return n;
321        }
322
323        // Use each indicator's own warmup calculation instead of parsing
324        // key suffixes (which fails for compound indicators like MACD and
325        // Bollinger).  `.warmup(n)` on the builder still overrides this.
326        let max_warmup = self
327            .required_indicators()
328            .iter()
329            .map(|(_, indicator)| indicator.warmup_bars())
330            .max()
331            .unwrap_or(1);
332
333        max_warmup + 1
334    }
335
336    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
337        let candle = ctx.current_candle();
338
339        // Check exit conditions first (for existing positions)
340        if ctx.is_long() && self.exit_condition.evaluate(ctx) {
341            return Signal::exit(candle.timestamp, candle.close)
342                .with_reason(self.exit_condition.description());
343        }
344
345        if ctx.is_short()
346            && let Some(ref exit) = self.short_exit_condition
347            && exit.evaluate(ctx)
348        {
349            return Signal::exit(candle.timestamp, candle.close)
350                .with_reason(exit.description().to_string());
351        }
352
353        // Check entry conditions (when no position)
354        if !ctx.has_position() {
355            // Regime filter gates all entries; exits are never suppressed.
356            let regime_ok = self
357                .regime_filter
358                .as_ref()
359                .is_none_or(|rf| rf.evaluate(ctx));
360
361            if regime_ok {
362                // Long entry
363                if self.entry_condition.evaluate(ctx) {
364                    return Signal::long(candle.timestamp, candle.close)
365                        .with_reason(self.entry_condition.description());
366                }
367
368                // Short entry
369                if let Some(ref entry) = self.short_entry_condition
370                    && entry.evaluate(ctx)
371                {
372                    return Signal::short(candle.timestamp, candle.close)
373                        .with_reason(entry.description().to_string());
374                }
375            }
376        }
377
378        Signal::hold()
379    }
380}
381
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::backtesting::condition::{always_false, always_true};
    use crate::backtesting::signal::SignalDirection;
    use crate::models::chart::Candle;

    /// Flat candle (open == high == low == close) at timestamp `ts`.
    fn make_candle(ts: i64, close: f64) -> Candle {
        Candle {
            timestamp: ts,
            open: close,
            high: close,
            low: close,
            close,
            volume: 1000,
            adj_close: None,
        }
    }

    /// Context at bar 0 with no open position.
    fn make_ctx<'a>(
        candles: &'a [Candle],
        indicators: &'a HashMap<String, Vec<Option<f64>>>,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index: 0,
            position: None,
            equity: 10_000.0,
            indicators,
        }
    }

    #[test]
    fn test_strategy_builder() {
        let strategy = StrategyBuilder::new("Test Strategy")
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_eq!(strategy.name(), "Test Strategy");
    }

    #[test]
    fn test_strategy_builder_with_short() {
        let strategy = StrategyBuilder::new("Test Strategy")
            .entry(always_true())
            .exit(always_false())
            .with_short(always_false(), always_true())
            .build();

        assert_eq!(strategy.name(), "Test Strategy");
        assert!(strategy.short_entry_condition.is_some());
        assert!(strategy.short_exit_condition.is_some());
    }

    #[test]
    fn test_required_indicators_deduplication() {
        use crate::backtesting::condition::Above;
        use crate::backtesting::refs::rsi;

        // Create two conditions using the same indicator
        let entry = Above::new(rsi(14), 70.0);
        let exit = Above::new(rsi(14), 30.0);

        let strategy = StrategyBuilder::new("Test").entry(entry).exit(exit).build();

        let indicators = strategy.required_indicators();
        // Should be deduplicated to just one rsi_14
        assert_eq!(indicators.len(), 1);
        assert_eq!(indicators[0].0, "rsi_14");
    }

    // ── Warmup tests ───────────────────────────────────────────────────

    #[test]
    fn test_warmup_override_takes_precedence() {
        use crate::backtesting::refs::{IndicatorRefExt, sma};

        let strategy = StrategyBuilder::new("Warmup Override")
            .regime_filter(sma(400).above_ref(sma(200)))
            .entry(always_true())
            .exit(always_false())
            .warmup(36)
            .build();

        // The explicit override wins even though sma(400) needs more bars.
        assert_eq!(strategy.warmup_period(), 36);
    }

    // ── Short-side tests ───────────────────────────────────────────────

    #[test]
    fn test_short_entry_fires_when_long_entry_does_not() {
        let strategy = StrategyBuilder::new("Short Entry")
            .entry(always_false())
            .exit(always_false())
            .with_short(always_true(), always_false())
            .build();

        let candles = vec![make_candle(1, 100.0)];
        let indicators = HashMap::new();
        let ctx = make_ctx(&candles, &indicators);

        assert_eq!(strategy.on_candle(&ctx).direction, SignalDirection::Short);
    }

    #[test]
    fn test_long_entry_takes_priority_over_short_entry() {
        let strategy = StrategyBuilder::new("Entry Priority")
            .entry(always_true())
            .exit(always_false())
            .with_short(always_true(), always_false())
            .build();

        let candles = vec![make_candle(1, 100.0)];
        let indicators = HashMap::new();
        let ctx = make_ctx(&candles, &indicators);

        // Both entries fire; the long entry is checked first and wins.
        assert_eq!(strategy.on_candle(&ctx).direction, SignalDirection::Long);
    }

    // ── Regime filter tests ────────────────────────────────────────────

    #[test]
    fn test_regime_filter_suppresses_entry_when_false() {
        let strategy = StrategyBuilder::new("Regime Test")
            .regime_filter(always_false()) // regime is never active
            .entry(always_true())
            .exit(always_false())
            .build();

        let candles = vec![make_candle(1, 100.0)];
        let indicators = HashMap::new();
        let ctx = make_ctx(&candles, &indicators);

        // Entry should be blocked by the regime filter
        assert_eq!(strategy.on_candle(&ctx).direction, SignalDirection::Hold);
    }

    #[test]
    fn test_regime_filter_suppresses_short_entry_when_false() {
        let strategy = StrategyBuilder::new("Regime Short Test")
            .regime_filter(always_false()) // regime is never active
            .entry(always_false())
            .exit(always_false())
            .with_short(always_true(), always_false())
            .build();

        let candles = vec![make_candle(1, 100.0)];
        let indicators = HashMap::new();
        let ctx = make_ctx(&candles, &indicators);

        // Short entries are gated by the regime filter just like long entries
        assert_eq!(strategy.on_candle(&ctx).direction, SignalDirection::Hold);
    }

    #[test]
    fn test_regime_filter_allows_entry_when_true() {
        let strategy = StrategyBuilder::new("Regime Test")
            .regime_filter(always_true()) // regime always active
            .entry(always_true())
            .exit(always_false())
            .build();

        let candles = vec![make_candle(1, 100.0)];
        let indicators = HashMap::new();
        let ctx = make_ctx(&candles, &indicators);

        assert_eq!(strategy.on_candle(&ctx).direction, SignalDirection::Long);
    }

    #[test]
    fn test_no_regime_filter_behaves_normally() {
        let strategy = StrategyBuilder::new("No Regime")
            .entry(always_true())
            .exit(always_false())
            .build();

        let candles = vec![make_candle(1, 100.0)];
        let indicators = HashMap::new();
        let ctx = make_ctx(&candles, &indicators);

        assert_eq!(strategy.on_candle(&ctx).direction, SignalDirection::Long);
    }

    #[test]
    fn test_regime_filter_does_not_block_exit() {
        use crate::backtesting::position::{Position, PositionSide};

        let strategy = StrategyBuilder::new("Regime Exit Test")
            .regime_filter(always_false()) // regime is off
            .entry(always_false())
            .exit(always_true()) // exit condition always fires
            .build();

        let candles = vec![make_candle(1, 100.0)];
        let indicators = HashMap::new();

        // Simulate an open long position using the public constructor
        let position = Position::new(
            PositionSide::Long,
            1,
            90.0,
            10.0,
            0.0,
            Signal::long(1, 90.0),
        );

        let ctx = StrategyContext {
            candles: &candles,
            index: 0,
            position: Some(&position),
            equity: 10_000.0,
            indicators: &indicators,
        };

        // Exit must fire even though regime filter is false
        assert_eq!(strategy.on_candle(&ctx).direction, SignalDirection::Exit);
    }

    #[test]
    fn test_regime_filter_indicators_included_in_required() {
        use crate::backtesting::refs::{IndicatorRefExt, sma};
        use crate::indicators::Indicator;

        let strategy = StrategyBuilder::new("Regime Indicators")
            .regime_filter(sma(200).above_ref(sma(400)))
            .entry(always_true())
            .exit(always_false())
            .build();

        let indicators = strategy.required_indicators();
        let keys: Vec<&str> = indicators.iter().map(|(k, _)| k.as_str()).collect();

        assert!(
            keys.contains(&"sma_200"),
            "sma_200 must be in required_indicators"
        );
        assert!(
            keys.contains(&"sma_400"),
            "sma_400 must be in required_indicators"
        );

        // Verify correct Indicator variants
        let sma_200 = indicators.iter().find(|(k, _)| k == "sma_200").unwrap();
        assert!(matches!(sma_200.1, Indicator::Sma(200)));
    }

    #[test]
    fn test_regime_filter_callable_before_entry() {
        // Verify the builder chain compiles when regime_filter is called first
        let strategy = StrategyBuilder::new("Order Test")
            .regime_filter(always_true())
            .entry(always_true())
            .exit(always_false())
            .build();

        assert!(strategy.regime_filter.is_some());
    }

    #[test]
    fn test_regime_filter_warmup_accounts_for_filter_indicators() {
        use crate::backtesting::refs::{IndicatorRefExt, sma};

        let strategy = StrategyBuilder::new("Warmup Test")
            .regime_filter(sma(400).above_ref(sma(200)))
            .entry(always_true())
            .exit(always_false())
            .build();

        // Warmup must be at least sma(400).warmup_bars() + 1 = 401
        assert!(
            strategy.warmup_period() >= 401,
            "warmup_period must account for sma(400): got {}",
            strategy.warmup_period()
        );
    }
}