mod builder;
mod ensemble;
pub mod prebuilt;
use std::collections::HashMap;
use crate::backtesting::condition::HtfIndicatorSpec;
use crate::indicators::Indicator;
use crate::models::chart::Candle;
use super::position::{Position, PositionSide};
use super::signal::Signal;
pub use builder::{CustomStrategy, StrategyBuilder};
pub use ensemble::{EnsembleMode, EnsembleStrategy};
pub use prebuilt::{
BollingerMeanReversion, DonchianBreakout, MacdSignal, RsiReversal, SmaCrossover,
SuperTrendFollow,
};
/// Snapshot of backtest state passed to [`Strategy::on_candle`] for a single bar.
///
/// `#[non_exhaustive]` prevents code outside this crate from building the struct with a
/// struct literal. Fields remain publicly readable.
#[non_exhaustive]
pub struct StrategyContext<'a> {
    /// The candle series the strategy is evaluated against.
    pub candles: &'a [Candle],
    /// Position within `candles` of the bar being evaluated.
    pub index: usize,
    /// Pre-computed indicator series keyed by name.
    ///
    /// The accessor methods index each series with the same `index` used for `candles`.
    /// The series are therefore presumably aligned bar-for-bar with `candles`; confirm
    /// this against the engine that builds the map. `None` entries mark bars where a
    /// value is unavailable.
    pub indicators: &'a HashMap<String, Vec<Option<f64>>>,
    /// The currently open position, if there is one.
    pub position: Option<&'a Position>,
    /// Presumably the account equity at this bar. No method in this module reads it;
    /// confirm against the caller.
    pub equity: f64,
}
impl<'a> StrategyContext<'a> {
    /// Returns the candle at `self.index`, i.e. the bar being evaluated.
    ///
    /// # Panics
    ///
    /// Panics if `self.index` is out of bounds for `self.candles`.
    pub fn current_candle(&self) -> &Candle {
        &self.candles[self.index]
    }

    /// Returns the candle one bar before the current one.
    ///
    /// Returns `None` on the first bar, or when that position lies outside `candles`.
    pub fn previous_candle(&self) -> Option<&Candle> {
        self.index
            .checked_sub(1)
            .and_then(|prev| self.candles.get(prev))
    }

    /// Returns the candle at an arbitrary `index`, or `None` if it is out of range.
    pub fn candle_at(&self, index: usize) -> Option<&Candle> {
        self.candles.get(index)
    }

    /// Value of indicator `name` on the current bar.
    ///
    /// Returns `None` in three cases: the indicator is unknown, its series is too short,
    /// or its value at this bar is `None`.
    pub fn indicator(&self, name: &str) -> Option<f64> {
        self.indicator_at(name, self.index)
    }

    /// Value of indicator `name` at an arbitrary `index`.
    ///
    /// Returns `None` when the name is unknown, the index is out of range, or the stored
    /// value is `None`.
    pub fn indicator_at(&self, name: &str, index: usize) -> Option<f64> {
        self.indicators
            .get(name)
            .and_then(|series| series.get(index))
            .and_then(|&value| value)
    }

    /// Value of indicator `name` on the previous bar, or `None` on the first bar.
    pub fn indicator_prev(&self, name: &str) -> Option<f64> {
        self.index
            .checked_sub(1)
            .and_then(|prev| self.indicator_at(name, prev))
    }

    /// `true` when any position is open.
    pub fn has_position(&self) -> bool {
        self.position.is_some()
    }

    /// `true` when the open position's side is `PositionSide::Long`.
    pub fn is_long(&self) -> bool {
        self.position
            .map(|p| matches!(p.side, PositionSide::Long))
            .unwrap_or(false)
    }

    /// `true` when the open position's side is `PositionSide::Short`.
    pub fn is_short(&self) -> bool {
        self.position
            .map(|p| matches!(p.side, PositionSide::Short))
            .unwrap_or(false)
    }

    /// Close of the current candle. Panics under the same conditions as `current_candle`.
    pub fn close(&self) -> f64 {
        self.current_candle().close
    }

    /// Open of the current candle. Panics under the same conditions as `current_candle`.
    pub fn open(&self) -> f64 {
        self.current_candle().open
    }

    /// High of the current candle. Panics under the same conditions as `current_candle`.
    pub fn high(&self) -> f64 {
        self.current_candle().high
    }

    /// Low of the current candle. Panics under the same conditions as `current_candle`.
    pub fn low(&self) -> f64 {
        self.current_candle().low
    }

    /// Volume of the current candle. Panics under the same conditions as `current_candle`.
    pub fn volume(&self) -> i64 {
        self.current_candle().volume
    }

    /// Timestamp of the current candle. Panics under the same conditions as `current_candle`.
    pub fn timestamp(&self) -> i64 {
        self.current_candle().timestamp
    }

    /// Shorthand for `Signal::long(self.timestamp(), self.close())`.
    pub fn signal_long(&self) -> Signal {
        Signal::long(self.timestamp(), self.close())
    }

    /// Shorthand for `Signal::short(self.timestamp(), self.close())`.
    pub fn signal_short(&self) -> Signal {
        Signal::short(self.timestamp(), self.close())
    }

    /// Shorthand for `Signal::exit(self.timestamp(), self.close())`.
    pub fn signal_exit(&self) -> Signal {
        Signal::exit(self.timestamp(), self.close())
    }

    /// `true` when series `fast_name` crosses above series `slow_name` on this bar.
    ///
    /// A crossover means fast was at or below slow on the previous bar and is strictly
    /// above it now. The previous bar uses `<=` (not `<`) so a series that touches and
    /// then breaks through is still detected, matching `indicator_crossed_above`. As a
    /// consequence, a touch followed by a bounce back above also counts.
    ///
    /// Returns `false` if any of the four values is missing, including on the first bar.
    pub fn crossed_above(&self, fast_name: &str, slow_name: &str) -> bool {
        match (
            self.indicator_prev(fast_name),
            self.indicator_prev(slow_name),
            self.indicator(fast_name),
            self.indicator(slow_name),
        ) {
            (Some(fast_prev), Some(slow_prev), Some(fast_now), Some(slow_now)) => {
                fast_prev <= slow_prev && fast_now > slow_now
            }
            _ => false,
        }
    }

    /// `true` when series `fast_name` crosses below series `slow_name` on this bar.
    ///
    /// This is the mirror of `crossed_above`: fast was at or above slow on the previous
    /// bar and is strictly below it now. Returns `false` if any value is missing.
    pub fn crossed_below(&self, fast_name: &str, slow_name: &str) -> bool {
        match (
            self.indicator_prev(fast_name),
            self.indicator_prev(slow_name),
            self.indicator(fast_name),
            self.indicator(slow_name),
        ) {
            (Some(fast_prev), Some(slow_prev), Some(fast_now), Some(slow_now)) => {
                fast_prev >= slow_prev && fast_now < slow_now
            }
            _ => false,
        }
    }

    /// `true` when indicator `name` moves from at-or-below `threshold` on the previous
    /// bar to strictly above it on the current bar.
    pub fn indicator_crossed_above(&self, name: &str, threshold: f64) -> bool {
        match (self.indicator(name), self.indicator_prev(name)) {
            (Some(now), Some(prev)) => prev <= threshold && now > threshold,
            _ => false,
        }
    }

    /// `true` when indicator `name` moves from at-or-above `threshold` on the previous
    /// bar to strictly below it on the current bar.
    pub fn indicator_crossed_below(&self, name: &str, threshold: f64) -> bool {
        match (self.indicator(name), self.indicator_prev(name)) {
            (Some(now), Some(prev)) => prev >= threshold && now < threshold,
            _ => false,
        }
    }
}
/// A trading strategy driven one bar at a time through [`StrategyContext`].
///
/// Implementors must be `Send + Sync`.
pub trait Strategy: Send + Sync {
    /// Display name of the strategy.
    fn name(&self) -> &str;

    /// Indicators this strategy needs, as `(key, indicator)` pairs.
    ///
    /// The key is presumably the name later passed to `StrategyContext::indicator`.
    /// Confirm this against the engine that computes the series.
    fn required_indicators(&self) -> Vec<(String, Indicator)>;

    /// Higher-timeframe indicator requirements. The default requests none.
    fn htf_requirements(&self) -> Vec<HtfIndicatorSpec> {
        Vec::new()
    }

    /// Optional mutable hook that receives the indicator table. The default does nothing.
    fn setup(&mut self, _indicators: &HashMap<String, Vec<Option<f64>>>) {}

    /// Produces the signal for the bar described by `ctx`.
    fn on_candle(&self, ctx: &StrategyContext) -> Signal;

    /// Number of warm-up bars the strategy requests. The default is `1`.
    fn warmup_period(&self) -> usize {
        1
    }
}
impl Strategy for Box<dyn Strategy> {
fn name(&self) -> &str {
(**self).name()
}
fn required_indicators(&self) -> Vec<(String, Indicator)> {
(**self).required_indicators()
}
fn htf_requirements(&self) -> Vec<HtfIndicatorSpec> {
(**self).htf_requirements()
}
fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
(**self).setup(indicators)
}
fn on_candle(&self, ctx: &StrategyContext) -> Signal {
(**self).on_candle(ctx)
}
fn warmup_period(&self) -> usize {
(**self).warmup_period()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture strategy: goes long on bar 5 and holds on every other bar.
    struct TestStrategy;

    impl Strategy for TestStrategy {
        fn name(&self) -> &str {
            "Test Strategy"
        }

        fn required_indicators(&self) -> Vec<(String, Indicator)> {
            vec![(String::from("sma_10"), Indicator::Sma(10))]
        }

        fn on_candle(&self, ctx: &StrategyContext) -> Signal {
            match ctx.index {
                5 => Signal::long(ctx.timestamp(), ctx.close()),
                _ => Signal::hold(),
            }
        }
    }

    /// Builds a fixture candle with fixed open, low and volume.
    fn bar(timestamp: i64, high: f64, close: f64) -> Candle {
        Candle {
            timestamp,
            open: 100.0,
            high,
            low: 99.0,
            close,
            volume: 1000,
            adj_close: None,
        }
    }

    #[test]
    fn test_strategy_trait() {
        let strategy = TestStrategy;
        assert_eq!(strategy.name(), "Test Strategy");
        assert_eq!(strategy.required_indicators().len(), 1);
        // TestStrategy does not override warmup_period, so the trait default applies.
        assert_eq!(strategy.warmup_period(), 1);
    }

    #[test]
    fn test_context_crossover_detection() {
        let candles = vec![bar(1, 101.0, 100.0), bar(2, 102.0, 101.0)];

        // fast goes 9 -> 11 while slow stays at 10: a clean upward cross on bar 1.
        let mut indicators = HashMap::new();
        indicators.insert(String::from("fast"), vec![Some(9.0), Some(11.0)]);
        indicators.insert(String::from("slow"), vec![Some(10.0), Some(10.0)]);

        let ctx = StrategyContext {
            candles: &candles,
            index: 1,
            position: None,
            equity: 10000.0,
            indicators: &indicators,
        };

        assert!(ctx.crossed_above("fast", "slow"));
        assert!(!ctx.crossed_below("fast", "slow"));
    }
}