use std::collections::HashSet;
use crate::backtesting::condition::{Condition, HtfIndicatorSpec};
use crate::backtesting::signal::Signal;
use crate::indicators::Indicator;
use super::{Strategy, StrategyContext};
/// A type-erased [`Condition`].
///
/// The wrapped condition's metadata (indicator requirements, higher-timeframe
/// requirements and description) is read once in [`BoxedCondition::new`] and
/// stored; the condition itself is moved into a boxed closure used for
/// evaluation. This lets one struct hold optional conditions of differing
/// concrete types.
struct BoxedCondition {
    /// Evaluates the wrapped condition against a strategy context.
    evaluate_fn: Box<dyn Fn(&StrategyContext) -> bool + Send + Sync>,
    /// Snapshot of the condition's `required_indicators()` taken at construction.
    required_indicators: Vec<(String, Indicator)>,
    /// Snapshot of the condition's `htf_requirements()` taken at construction.
    htf_requirements: Vec<HtfIndicatorSpec>,
    /// Snapshot of the condition's `description()` taken at construction.
    description: String,
}

impl BoxedCondition {
    /// Wraps `cond`, caching its metadata before moving it into the evaluator.
    fn new<C: Condition>(cond: C) -> Self {
        // Struct-literal initialisers run top to bottom, so every metadata
        // call borrows `cond` before the last initialiser moves it away.
        Self {
            required_indicators: cond.required_indicators(),
            htf_requirements: cond.htf_requirements(),
            description: cond.description(),
            evaluate_fn: Box::new(move |ctx| cond.evaluate(ctx)),
        }
    }

    /// Runs the wrapped condition's `evaluate` on `ctx`.
    fn evaluate(&self, ctx: &StrategyContext) -> bool {
        let eval = &self.evaluate_fn;
        eval(ctx)
    }

    /// Indicators the wrapped condition needs, as captured at construction.
    fn required_indicators(&self) -> &[(String, Indicator)] {
        self.required_indicators.as_slice()
    }

    /// Higher-timeframe requirements of the wrapped condition, as captured at construction.
    fn htf_requirements(&self) -> &[HtfIndicatorSpec] {
        self.htf_requirements.as_slice()
    }

    /// Description of the wrapped condition, as captured at construction.
    fn description(&self) -> &str {
        self.description.as_str()
    }
}
/// Type-state builder for [`CustomStrategy`].
///
/// `E` and `X` are the long entry and exit condition types. Both start as `()`
/// (see [`StrategyBuilder::new`]). [`StrategyBuilder::entry`] and
/// [`StrategyBuilder::exit`] can each be called once to replace the
/// corresponding `()`. `build` only exists once both are [`Condition`]s, so a
/// strategy cannot be built without entry and exit rules.
///
/// The type is `#[must_use]`: a builder chain used as a bare statement
/// (typically a forgotten `.build()`) produces a compiler warning instead of
/// being silently dropped.
#[must_use = "a StrategyBuilder does nothing until `.build()` is called"]
pub struct StrategyBuilder<E = (), X = ()> {
    /// Human-readable strategy name, passed through to the built strategy.
    name: String,
    /// Long-entry condition, or `()` while not yet set.
    entry_condition: E,
    /// Long-exit condition, or `()` while not yet set.
    exit_condition: X,
    /// Optional type-erased short-entry condition.
    short_entry_condition: Option<BoxedCondition>,
    /// Optional type-erased short-exit condition.
    short_exit_condition: Option<BoxedCondition>,
    /// Optional type-erased filter that gates new entries (never exits).
    regime_filter: Option<BoxedCondition>,
    /// Explicit warmup length in bars; `None` means "derive from indicators".
    warmup_override: Option<usize>,
}
impl StrategyBuilder<(), ()> {
    /// Starts a builder for a strategy called `name`.
    ///
    /// No entry or exit condition is attached yet (both are `()`). No short
    /// conditions, regime filter or warmup override are set either.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self {
            name,
            entry_condition: (),
            exit_condition: (),
            warmup_override: None,
            regime_filter: None,
            short_entry_condition: None,
            short_exit_condition: None,
        }
    }
}
impl<X> StrategyBuilder<(), X> {
    /// Sets the long-entry condition.
    ///
    /// Only available while no entry condition has been set (`E = ()`), so it
    /// can be called at most once. Everything else configured so far, including
    /// an already-supplied exit condition, is carried over unchanged.
    pub fn entry<C: Condition>(self, condition: C) -> StrategyBuilder<C, X> {
        let StrategyBuilder {
            name,
            entry_condition: (),
            exit_condition,
            short_entry_condition,
            short_exit_condition,
            regime_filter,
            warmup_override,
        } = self;

        StrategyBuilder {
            name,
            entry_condition: condition,
            exit_condition,
            short_entry_condition,
            short_exit_condition,
            regime_filter,
            warmup_override,
        }
    }
}
impl<E> StrategyBuilder<E, ()> {
    /// Sets the long-exit condition.
    ///
    /// Only available while no exit condition has been set (`X = ()`), so it
    /// can be called at most once. Everything else configured so far, including
    /// an already-supplied entry condition, is carried over unchanged.
    pub fn exit<C: Condition>(self, condition: C) -> StrategyBuilder<E, C> {
        let StrategyBuilder {
            name,
            entry_condition,
            exit_condition: (),
            short_entry_condition,
            short_exit_condition,
            regime_filter,
            warmup_override,
        } = self;

        StrategyBuilder {
            name,
            entry_condition,
            exit_condition: condition,
            short_entry_condition,
            short_exit_condition,
            regime_filter,
            warmup_override,
        }
    }
}
// Optional settings are available in every builder state, before or after
// `entry`/`exit`. Previously `with_short` and `warmup` required both long
// conditions to be set first, which was inconsistent with `regime_filter`.
impl<E, X> StrategyBuilder<E, X> {
    /// Sets a regime filter that gates new entries.
    ///
    /// While flat, a filter that evaluates to `false` suppresses both long and
    /// short entries for that bar. Exits of open positions are never blocked.
    /// Calling this again replaces the previous filter.
    pub fn regime_filter<C: Condition>(mut self, condition: C) -> Self {
        self.regime_filter = Some(BoxedCondition::new(condition));
        self
    }

    /// Enables short trading with the given entry and exit conditions.
    ///
    /// The short entry is only consulted when the long entry condition did not
    /// fire on the same bar. Calling this again replaces both short conditions.
    pub fn with_short<SE: Condition, SX: Condition>(mut self, entry: SE, exit: SX) -> Self {
        self.short_entry_condition = Some(BoxedCondition::new(entry));
        self.short_exit_condition = Some(BoxedCondition::new(exit));
        self
    }

    /// Overrides the warmup period with an explicit number of bars.
    ///
    /// Without an override, the built strategy derives its warmup from the
    /// largest `warmup_bars()` among its required indicators.
    pub fn warmup(mut self, bars: usize) -> Self {
        self.warmup_override = Some(bars);
        self
    }
}

impl<E: Condition, X: Condition> StrategyBuilder<E, X> {
    /// Finishes the builder.
    ///
    /// Only callable once both `entry` and `exit` have been supplied.
    pub fn build(self) -> CustomStrategy<E, X> {
        CustomStrategy {
            name: self.name,
            entry_condition: self.entry_condition,
            exit_condition: self.exit_condition,
            short_entry_condition: self.short_entry_condition,
            short_exit_condition: self.short_exit_condition,
            regime_filter: self.regime_filter,
            warmup_override: self.warmup_override,
        }
    }
}
pub struct CustomStrategy<E: Condition, X: Condition> {
name: String,
entry_condition: E,
exit_condition: X,
short_entry_condition: Option<BoxedCondition>,
short_exit_condition: Option<BoxedCondition>,
regime_filter: Option<BoxedCondition>,
warmup_override: Option<usize>,
}
impl<E: Condition, X: Condition> Strategy for CustomStrategy<E, X> {
fn name(&self) -> &str {
&self.name
}
fn required_indicators(&self) -> Vec<(String, Indicator)> {
let mut indicators = self.entry_condition.required_indicators();
indicators.extend(self.exit_condition.required_indicators());
if let Some(ref se) = self.short_entry_condition {
indicators.extend(se.required_indicators().iter().cloned());
}
if let Some(ref sx) = self.short_exit_condition {
indicators.extend(sx.required_indicators().iter().cloned());
}
if let Some(ref rf) = self.regime_filter {
indicators.extend(rf.required_indicators().iter().cloned());
}
let mut seen = HashSet::new();
indicators.retain(|(key, _)| seen.insert(key.clone()));
indicators
}
fn htf_requirements(&self) -> Vec<HtfIndicatorSpec> {
let mut reqs = self.entry_condition.htf_requirements();
reqs.extend(self.exit_condition.htf_requirements());
if let Some(ref se) = self.short_entry_condition {
reqs.extend(se.htf_requirements().iter().cloned());
}
if let Some(ref sx) = self.short_exit_condition {
reqs.extend(sx.htf_requirements().iter().cloned());
}
if let Some(ref rf) = self.regime_filter {
reqs.extend(rf.htf_requirements().iter().cloned());
}
let mut seen = HashSet::new();
reqs.retain(|spec| seen.insert(spec.htf_key.clone()));
reqs
}
fn warmup_period(&self) -> usize {
if let Some(n) = self.warmup_override {
return n;
}
let max_warmup = self
.required_indicators()
.iter()
.map(|(_, indicator)| indicator.warmup_bars())
.max()
.unwrap_or(1);
max_warmup + 1
}
fn on_candle(&self, ctx: &StrategyContext) -> Signal {
let candle = ctx.current_candle();
if ctx.is_long() && self.exit_condition.evaluate(ctx) {
return Signal::exit(candle.timestamp, candle.close)
.with_reason(self.exit_condition.description());
}
if ctx.is_short()
&& let Some(ref exit) = self.short_exit_condition
&& exit.evaluate(ctx)
{
return Signal::exit(candle.timestamp, candle.close)
.with_reason(exit.description().to_string());
}
if !ctx.has_position() {
let regime_ok = self
.regime_filter
.as_ref()
.is_none_or(|rf| rf.evaluate(ctx));
if regime_ok {
if self.entry_condition.evaluate(ctx) {
return Signal::long(candle.timestamp, candle.close)
.with_reason(self.entry_condition.description());
}
if let Some(ref entry) = self.short_entry_condition
&& entry.evaluate(ctx)
{
return Signal::short(candle.timestamp, candle.close)
.with_reason(entry.description().to_string());
}
}
}
Signal::hold()
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::backtesting::condition::{always_false, always_true};
    use crate::backtesting::signal::SignalDirection;
    use crate::models::chart::Candle;

    /// A flat candle (open = high = low = close) with a fixed volume.
    fn make_candle(ts: i64, close: f64) -> Candle {
        Candle {
            timestamp: ts,
            open: close,
            high: close,
            low: close,
            close,
            volume: 1000,
            adj_close: None,
        }
    }

    /// A context on the first candle with no open position and 10k equity.
    fn make_ctx<'a>(
        candles: &'a [Candle],
        indicators: &'a HashMap<String, Vec<Option<f64>>>,
    ) -> StrategyContext<'a> {
        StrategyContext {
            candles,
            index: 0,
            position: None,
            equity: 10_000.0,
            indicators,
        }
    }

    /// Runs `strategy` once on a single 100.0 candle while flat, with no indicator data.
    fn signal_when_flat(strategy: &impl Strategy) -> Signal {
        let candles = vec![make_candle(1, 100.0)];
        let indicators = HashMap::new();
        let ctx = make_ctx(&candles, &indicators);
        strategy.on_candle(&ctx)
    }

    #[test]
    fn test_strategy_builder() {
        let strategy = StrategyBuilder::new("Test Strategy")
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_eq!(strategy.name(), "Test Strategy");
    }

    #[test]
    fn test_strategy_builder_with_short() {
        let strategy = StrategyBuilder::new("Test Strategy")
            .entry(always_true())
            .exit(always_false())
            .with_short(always_false(), always_true())
            .build();

        assert_eq!(strategy.name(), "Test Strategy");
        assert!(strategy.short_entry_condition.is_some());
        assert!(strategy.short_exit_condition.is_some());
    }

    #[test]
    fn test_required_indicators_deduplication() {
        use crate::backtesting::condition::Above;
        use crate::backtesting::refs::rsi;

        let strategy = StrategyBuilder::new("Test")
            .entry(Above::new(rsi(14), 70.0))
            .exit(Above::new(rsi(14), 30.0))
            .build();

        let indicators = strategy.required_indicators();
        assert_eq!(indicators.len(), 1);
        assert_eq!(indicators[0].0, "rsi_14");
    }

    #[test]
    fn test_regime_filter_suppresses_entry_when_false() {
        let strategy = StrategyBuilder::new("Regime Test")
            .regime_filter(always_false())
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_eq!(signal_when_flat(&strategy).direction, SignalDirection::Hold);
    }

    #[test]
    fn test_regime_filter_allows_entry_when_true() {
        let strategy = StrategyBuilder::new("Regime Test")
            .regime_filter(always_true())
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_eq!(signal_when_flat(&strategy).direction, SignalDirection::Long);
    }

    #[test]
    fn test_no_regime_filter_behaves_normally() {
        let strategy = StrategyBuilder::new("No Regime")
            .entry(always_true())
            .exit(always_false())
            .build();

        assert_eq!(signal_when_flat(&strategy).direction, SignalDirection::Long);
    }

    #[test]
    fn test_regime_filter_does_not_block_exit() {
        use crate::backtesting::position::{Position, PositionSide};

        let strategy = StrategyBuilder::new("Regime Exit Test")
            .regime_filter(always_false())
            .entry(always_false())
            .exit(always_true())
            .build();

        let candles = vec![make_candle(1, 100.0)];
        let indicators = HashMap::new();
        let position = Position::new(
            PositionSide::Long,
            1,
            90.0,
            10.0,
            0.0,
            Signal::long(1, 90.0),
        );
        let ctx = StrategyContext {
            candles: &candles,
            index: 0,
            position: Some(&position),
            equity: 10_000.0,
            indicators: &indicators,
        };

        assert_eq!(strategy.on_candle(&ctx).direction, SignalDirection::Exit);
    }

    #[test]
    fn test_regime_filter_indicators_included_in_required() {
        use crate::backtesting::refs::{IndicatorRefExt, sma};
        use crate::indicators::Indicator;

        let strategy = StrategyBuilder::new("Regime Indicators")
            .regime_filter(sma(200).above_ref(sma(400)))
            .entry(always_true())
            .exit(always_false())
            .build();

        let indicators = strategy.required_indicators();
        let keys: Vec<&str> = indicators.iter().map(|(key, _)| key.as_str()).collect();
        assert!(
            keys.contains(&"sma_200"),
            "sma_200 must be in required_indicators"
        );
        assert!(
            keys.contains(&"sma_400"),
            "sma_400 must be in required_indicators"
        );

        let sma_200 = indicators
            .iter()
            .find(|(key, _)| key == "sma_200")
            .unwrap();
        assert!(matches!(sma_200.1, Indicator::Sma(200)));
    }

    #[test]
    fn test_regime_filter_callable_before_entry() {
        let strategy = StrategyBuilder::new("Order Test")
            .regime_filter(always_true())
            .entry(always_true())
            .exit(always_false())
            .build();

        assert!(strategy.regime_filter.is_some());
    }

    #[test]
    fn test_regime_filter_warmup_accounts_for_filter_indicators() {
        use crate::backtesting::refs::{IndicatorRefExt, sma};

        let strategy = StrategyBuilder::new("Warmup Test")
            .regime_filter(sma(400).above_ref(sma(200)))
            .entry(always_true())
            .exit(always_false())
            .build();

        let warmup = strategy.warmup_period();
        assert!(
            warmup >= 401,
            "warmup_period must account for sma(400): got {}",
            warmup
        );
    }
}