use std::collections::HashMap;
use crate::indicators::Indicator;
use super::{Signal, Strategy, StrategyContext};
use crate::backtesting::signal::SignalStrength;
/// Cached raw pointer to one indicator series, filled in by a strategy's `setup`.
///
/// `None` means nothing was cached; every `on_candle` in this file then falls back to
/// looking the series up in `ctx.indicators`. Caching the pointer avoids that `HashMap`
/// lookup on every candle.
///
/// NOTE(review): the pointer refers into the `HashMap` passed to `setup` but carries no
/// lifetime. It dangles if that map is dropped, or mutated in a way that moves its values
/// (e.g. an insert that triggers a rehash). This presumably relies on the engine keeping the
/// same map alive and unmodified for the whole run — TODO confirm against the engine.
#[derive(Debug, Default)]
struct IndicatorSlot(Option<*const Vec<Option<f64>>>);
// SAFETY (not verified in this file): raw pointers are neither `Send` nor `Sync`. These
// impls assume the pointed-to series is only ever read, never mutated or freed, while a
// strategy holding the slot is moved to or shared between threads.
// NOTE(review): nothing here enforces that — confirm with the engine.
unsafe impl Send for IndicatorSlot {}
unsafe impl Sync for IndicatorSlot {}
// Cloning intentionally drops the cached pointer (presumably so a clone never reads through
// a pointer into a map it was not `setup` with). A clone falls back to `ctx.indicators`
// until `setup` is called on it.
impl Clone for IndicatorSlot {
fn clone(&self) -> Self {
IndicatorSlot(None)
}
}
impl IndicatorSlot {
/// Caches the address of `v`, overwriting any previously cached pointer.
fn set(&mut self, v: &Vec<Option<f64>>) {
self.0 = Some(v as *const _);
}
/// Returns the cached series, or `None` if `set` was never called on this slot.
///
/// # Safety
///
/// If a pointer is cached, the `Vec` it was taken from must still be alive at the same
/// address and must not be mutated while the returned reference is in use. The returned
/// lifetime is tied to `&self`, not to the map, so the compiler cannot check this.
#[inline]
unsafe fn get(&self) -> Option<&Vec<Option<f64>>> {
// SAFETY: upheld by the caller per the contract above.
self.0.map(|p| unsafe { &*p })
}
}
/// Trend-following strategy that trades crossovers of a fast and a slow simple moving average.
///
/// The SMA series are read from the indicator keys `sma_{fast_period}` and `sma_{slow_period}`.
#[derive(Debug, Clone)]
pub struct SmaCrossover {
    /// Look-back length of the fast SMA.
    pub fast_period: usize,
    /// Look-back length of the slow SMA.
    pub slow_period: usize,
    fast_key: String,
    slow_key: String,
    fast_slot: IndicatorSlot,
    slow_slot: IndicatorSlot,
}

impl SmaCrossover {
    /// Builds a crossover strategy for the given fast and slow SMA periods.
    ///
    /// The periods are stored as given; nothing checks that `fast_period < slow_period`.
    pub fn new(fast_period: usize, slow_period: usize) -> Self {
        let sma_key = |period: usize| format!("sma_{period}");
        Self {
            fast_period,
            slow_period,
            fast_key: sma_key(fast_period),
            slow_key: sma_key(slow_period),
            fast_slot: IndicatorSlot::default(),
            slow_slot: IndicatorSlot::default(),
        }
    }
}

impl Default for SmaCrossover {
    /// The classic 10 / 20 SMA crossover.
    fn default() -> Self {
        SmaCrossover::new(10, 20)
    }
}
impl Strategy for SmaCrossover {
    fn name(&self) -> &str {
        "SMA Crossover"
    }

    /// Requests one SMA series per period, keyed `sma_{period}`.
    fn required_indicators(&self) -> Vec<(String, Indicator)> {
        let fast = (self.fast_key.clone(), Indicator::Sma(self.fast_period));
        let slow = (self.slow_key.clone(), Indicator::Sma(self.slow_period));
        vec![fast, slow]
    }

    /// Caches pointers to the fast and slow SMA series that are present in `indicators`.
    /// A missing key leaves its slot untouched.
    fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
        let slots = [
            (&self.fast_key, &mut self.fast_slot),
            (&self.slow_key, &mut self.slow_slot),
        ];
        for (key, slot) in slots {
            if let Some(series) = indicators.get(key) {
                slot.set(series);
            }
        }
    }

    /// Longest SMA period plus one bar, so the previous bar's values exist too.
    fn warmup_period(&self) -> usize {
        self.fast_period.max(self.slow_period) + 1
    }

    /// Signals on a fast/slow SMA crossover between the previous bar and the current one.
    ///
    /// A bullish cross closes a short, or opens a long when flat. A bearish cross closes a
    /// long, or opens a short when flat. Everything else, including missing data and bar 0,
    /// returns `hold`.
    ///
    /// Both comparisons are strict. A cross that passes through exact equality on the
    /// previous bar is therefore not reported.
    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
        /// Value at `idx`, treating out-of-range and not-yet-computed entries alike.
        fn at(series: &[Option<f64>], idx: usize) -> Option<f64> {
            series.get(idx).copied().flatten()
        }

        let candle = ctx.current_candle();
        let idx = ctx.index;
        // A crossover needs the previous bar, so bar 0 can never signal.
        let Some(prev_idx) = idx.checked_sub(1) else {
            return Signal::hold();
        };

        // SAFETY: relies on the `IndicatorSlot` contract. The map given to `setup` must still
        // be alive and unmodified. Unset slots fall back to the map carried by the context.
        let (cached_fast, cached_slow) = unsafe { (self.fast_slot.get(), self.slow_slot.get()) };
        let (Some(fast_series), Some(slow_series)) = (
            cached_fast.or_else(|| ctx.indicators.get(&self.fast_key)),
            cached_slow.or_else(|| ctx.indicators.get(&self.slow_key)),
        ) else {
            return Signal::hold();
        };

        let (Some(fast_now), Some(slow_now), Some(fast_prev), Some(slow_prev)) = (
            at(fast_series, idx),
            at(slow_series, idx),
            at(fast_series, prev_idx),
            at(slow_series, prev_idx),
        ) else {
            return Signal::hold();
        };

        let crossed_up = fast_prev < slow_prev && fast_now > slow_now;
        let crossed_down = fast_prev > slow_prev && fast_now < slow_now;

        // The two crosses are mutually exclusive (fast_now cannot be both > and < slow_now).
        if crossed_up {
            if ctx.is_short() {
                return Signal::exit(candle.timestamp, candle.close)
                    .with_reason("SMA bullish crossover - close short");
            }
            if !ctx.has_position() {
                return Signal::long(candle.timestamp, candle.close)
                    .with_reason("SMA bullish crossover");
            }
        } else if crossed_down {
            if ctx.is_long() {
                return Signal::exit(candle.timestamp, candle.close)
                    .with_reason("SMA bearish crossover - close long");
            }
            if !ctx.has_position() {
                return Signal::short(candle.timestamp, candle.close)
                    .with_reason("SMA bearish crossover");
            }
        }

        Signal::hold()
    }
}
/// Mean-reversion strategy that trades the RSI leaving an oversold/overbought extreme.
///
/// The RSI series is read from the indicator key `rsi_{period}`.
#[derive(Debug, Clone)]
pub struct RsiReversal {
    /// RSI look-back length.
    pub period: usize,
    /// Level the RSI must rise above (coming from at or below it) for a bullish signal.
    pub oversold: f64,
    /// Level the RSI must fall below (coming from at or above it) for a bearish signal.
    pub overbought: f64,
    rsi_key: String,
    rsi_slot: IndicatorSlot,
}

impl RsiReversal {
    /// Creates the strategy with the conventional 30 / 70 thresholds.
    pub fn new(period: usize) -> Self {
        let rsi_key = format!("rsi_{period}");
        Self {
            period,
            oversold: 30.0,
            overbought: 70.0,
            rsi_key,
            rsi_slot: IndicatorSlot::default(),
        }
    }

    /// Builder-style override of both thresholds; the values are not validated.
    pub fn with_thresholds(self, oversold: f64, overbought: f64) -> Self {
        Self {
            oversold,
            overbought,
            ..self
        }
    }
}

impl Default for RsiReversal {
    /// 14-period RSI with 30 / 70 thresholds.
    fn default() -> Self {
        RsiReversal::new(14)
    }
}
impl Strategy for RsiReversal {
    fn name(&self) -> &str {
        "RSI Reversal"
    }

    /// Requests a single RSI series keyed `rsi_{period}`.
    fn required_indicators(&self) -> Vec<(String, Indicator)> {
        let rsi = Indicator::Rsi(self.period);
        vec![(self.rsi_key.clone(), rsi)]
    }

    /// Caches a pointer to the RSI series if `indicators` contains it.
    fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
        let Some(series) = indicators.get(&self.rsi_key) else {
            return;
        };
        self.rsi_slot.set(series);
    }

    /// RSI period plus one bar, so a previous reading exists.
    fn warmup_period(&self) -> usize {
        1 + self.period
    }

    /// Signals when the RSI crosses back over a threshold between the previous and current bar.
    ///
    /// - Rising from at/below `oversold` to above it: close a short, or go long when flat.
    /// - Falling from at/above `overbought` to below it: close a long, or go short when flat.
    ///
    /// Signals are annotated with a strength derived from the current RSI
    /// (strong outside 20..=80, medium outside 25..=75, weak otherwise).
    ///
    /// NOTE(review): with the default 30/70 thresholds a bullish cross implies RSI > 30 and a
    /// bearish cross implies RSI < 70. The strong/medium bands are then reachable only through
    /// large single-bar jumps, so these signals are almost always "weak" — confirm intent.
    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
        let candle = ctx.current_candle();
        let idx = ctx.index;

        // SAFETY: relies on the `IndicatorSlot` contract. The map given to `setup` must still
        // be alive and unmodified. An unset slot falls back to the context's map.
        let cached = unsafe { self.rsi_slot.get() };
        let Some(series) = cached.or_else(|| ctx.indicators.get(&self.rsi_key)) else {
            return Signal::hold();
        };
        let rsi_at = |n: usize| series.get(n).copied().flatten();

        let Some(rsi_now) = rsi_at(idx) else {
            return Signal::hold();
        };
        // Bar 0 has no previous reading, which rules out any cross there.
        let rsi_before = idx.checked_sub(1).and_then(rsi_at);

        let strength = match rsi_now {
            v if !(20.0..=80.0).contains(&v) => SignalStrength::strong(),
            v if !(25.0..=75.0).contains(&v) => SignalStrength::medium(),
            _ => SignalStrength::weak(),
        };

        let crossed_above_oversold =
            matches!(rsi_before, Some(prev) if prev <= self.oversold) && rsi_now > self.oversold;
        if crossed_above_oversold {
            if ctx.is_short() {
                return Signal::exit(candle.timestamp, candle.close)
                    .with_strength(strength)
                    .with_reason(format!(
                        "RSI crossed above {:.0} - close short",
                        self.oversold
                    ));
            }
            if !ctx.has_position() {
                return Signal::long(candle.timestamp, candle.close)
                    .with_strength(strength)
                    .with_reason(format!("RSI crossed above {:.0}", self.oversold));
            }
        }

        let crossed_below_overbought = matches!(rsi_before, Some(prev) if prev >= self.overbought)
            && rsi_now < self.overbought;
        if crossed_below_overbought {
            if ctx.is_long() {
                return Signal::exit(candle.timestamp, candle.close)
                    .with_strength(strength)
                    .with_reason(format!(
                        "RSI crossed below {:.0} - close long",
                        self.overbought
                    ));
            }
            if !ctx.has_position() {
                return Signal::short(candle.timestamp, candle.close)
                    .with_strength(strength)
                    .with_reason(format!("RSI crossed below {:.0}", self.overbought));
            }
        }

        Signal::hold()
    }
}
/// Trend strategy that trades crossovers of the MACD line and its signal line.
///
/// The two series are read from the keys `macd_line_{fast}_{slow}_{signal}` and
/// `macd_signal_{fast}_{slow}_{signal}`.
#[derive(Debug, Clone)]
pub struct MacdSignal {
    /// Fast EMA period of the MACD line.
    pub fast: usize,
    /// Slow EMA period of the MACD line.
    pub slow: usize,
    /// Smoothing period of the signal line.
    pub signal: usize,
    line_key: String,
    sig_key: String,
    line_slot: IndicatorSlot,
    sig_slot: IndicatorSlot,
}

impl MacdSignal {
    /// Builds the strategy for the given MACD parameters.
    pub fn new(fast: usize, slow: usize, signal: usize) -> Self {
        let params = format!("{fast}_{slow}_{signal}");
        Self {
            fast,
            slow,
            signal,
            line_key: format!("macd_line_{params}"),
            sig_key: format!("macd_signal_{params}"),
            line_slot: IndicatorSlot::default(),
            sig_slot: IndicatorSlot::default(),
        }
    }
}

impl Default for MacdSignal {
    /// The standard 12 / 26 / 9 MACD.
    fn default() -> Self {
        MacdSignal::new(12, 26, 9)
    }
}
impl Strategy for MacdSignal {
    fn name(&self) -> &str {
        "MACD Signal"
    }

    /// Requests one MACD indicator registered under the name `macd`.
    ///
    /// NOTE(review): `setup`/`on_candle` read `macd_line_*` / `macd_signal_*` keys instead;
    /// presumably the engine expands MACD into those keys — confirm.
    fn required_indicators(&self) -> Vec<(String, Indicator)> {
        let macd = Indicator::Macd {
            fast: self.fast,
            slow: self.slow,
            signal: self.signal,
        };
        vec![(String::from("macd"), macd)]
    }

    /// Caches pointers to the MACD line and signal series that are present in `indicators`.
    fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
        let slots = [
            (&self.line_key, &mut self.line_slot),
            (&self.sig_key, &mut self.sig_slot),
        ];
        for (key, slot) in slots {
            if let Some(series) = indicators.get(key) {
                slot.set(series);
            }
        }
    }

    fn warmup_period(&self) -> usize {
        self.slow + self.signal
    }

    /// Signals when the MACD line crosses its signal line between the previous and current bar.
    ///
    /// A bullish cross closes a short, or opens a long when flat. A bearish cross closes a long,
    /// or opens a short when flat. Comparisons are strict, so a cross through exact equality on
    /// the previous bar is not reported.
    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
        /// Value at `idx`, treating out-of-range and not-yet-computed entries alike.
        fn at(series: &[Option<f64>], idx: usize) -> Option<f64> {
            series.get(idx).copied().flatten()
        }

        let candle = ctx.current_candle();
        let idx = ctx.index;
        let Some(prev_idx) = idx.checked_sub(1) else {
            return Signal::hold();
        };

        // SAFETY: relies on the `IndicatorSlot` contract. The map given to `setup` must still
        // be alive and unmodified. Unset slots fall back to the context's map.
        let (cached_line, cached_sig) = unsafe { (self.line_slot.get(), self.sig_slot.get()) };
        let (Some(line_series), Some(sig_series)) = (
            cached_line.or_else(|| ctx.indicators.get(&self.line_key)),
            cached_sig.or_else(|| ctx.indicators.get(&self.sig_key)),
        ) else {
            return Signal::hold();
        };

        let (Some(line_now), Some(sig_now), Some(line_prev), Some(sig_prev)) = (
            at(line_series, idx),
            at(sig_series, idx),
            at(line_series, prev_idx),
            at(sig_series, prev_idx),
        ) else {
            return Signal::hold();
        };

        let crossed_up = line_prev < sig_prev && line_now > sig_now;
        let crossed_down = line_prev > sig_prev && line_now < sig_now;

        // Mutually exclusive: line_now cannot be both above and below sig_now.
        if crossed_up {
            if ctx.is_short() {
                return Signal::exit(candle.timestamp, candle.close)
                    .with_reason("MACD bullish crossover - close short");
            }
            if !ctx.has_position() {
                return Signal::long(candle.timestamp, candle.close)
                    .with_reason("MACD bullish crossover");
            }
        } else if crossed_down {
            if ctx.is_long() {
                return Signal::exit(candle.timestamp, candle.close)
                    .with_reason("MACD bearish crossover - close long");
            }
            if !ctx.has_position() {
                return Signal::short(candle.timestamp, candle.close)
                    .with_reason("MACD bearish crossover");
            }
        }

        Signal::hold()
    }
}
/// Mean-reversion strategy that fades moves to the outer Bollinger Bands.
///
/// Band series are read from `bollinger_{lower,middle,upper}_{period}_{std_dev}`. The numeric
/// parts of these keys are formatted with `{}`, so `2.0` appears as `2`.
#[derive(Debug, Clone)]
pub struct BollingerMeanReversion {
    /// Moving-average length of the bands.
    pub period: usize,
    /// Band width in standard deviations.
    pub std_dev: f64,
    /// Exit at the middle band (`true`) or at the opposite outer band (`false`).
    pub exit_at_middle: bool,
    lower_key: String,
    middle_key: String,
    upper_key: String,
    lower_slot: IndicatorSlot,
    middle_slot: IndicatorSlot,
    upper_slot: IndicatorSlot,
}

impl BollingerMeanReversion {
    /// Builds the strategy; exits default to the middle band.
    pub fn new(period: usize, std_dev: f64) -> Self {
        let band_key = |band: &str| format!("bollinger_{band}_{period}_{std_dev}");
        Self {
            period,
            std_dev,
            exit_at_middle: true,
            lower_key: band_key("lower"),
            middle_key: band_key("middle"),
            upper_key: band_key("upper"),
            lower_slot: IndicatorSlot::default(),
            middle_slot: IndicatorSlot::default(),
            upper_slot: IndicatorSlot::default(),
        }
    }

    /// Builder-style switch between middle-band and opposite-band exits.
    pub fn exit_at_middle(self, at_middle: bool) -> Self {
        Self {
            exit_at_middle: at_middle,
            ..self
        }
    }
}

impl Default for BollingerMeanReversion {
    /// 20-period bands at 2 standard deviations.
    fn default() -> Self {
        BollingerMeanReversion::new(20, 2.0)
    }
}
impl Strategy for BollingerMeanReversion {
    fn name(&self) -> &str {
        "Bollinger Mean Reversion"
    }

    /// Requests one Bollinger indicator registered under the name `bollinger`.
    ///
    /// NOTE(review): `setup`/`on_candle` read `bollinger_{lower,middle,upper}_*` keys instead;
    /// presumably the engine expands the bands into those keys — confirm.
    fn required_indicators(&self) -> Vec<(String, Indicator)> {
        let bands = Indicator::Bollinger {
            period: self.period,
            std_dev: self.std_dev,
        };
        vec![(String::from("bollinger"), bands)]
    }

    /// Caches pointers to whichever band series are present in `indicators`.
    fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
        let slots = [
            (&self.lower_key, &mut self.lower_slot),
            (&self.middle_key, &mut self.middle_slot),
            (&self.upper_key, &mut self.upper_slot),
        ];
        for (key, slot) in slots {
            if let Some(series) = indicators.get(key) {
                slot.set(series);
            }
        }
    }

    fn warmup_period(&self) -> usize {
        self.period
    }

    /// Rules, checked in this order on the current bar:
    /// 1. flat and close at/below the lower band: go long;
    /// 2. long and close at/above the exit band (middle or upper): exit;
    /// 3. flat and close at/above the upper band: go short;
    /// 4. short and close at/below the exit band (middle or lower): exit.
    ///
    /// Holds when any band value for the current bar is unavailable.
    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
        let candle = ctx.current_candle();
        let close = candle.close;
        let idx = ctx.index;

        // SAFETY: relies on the `IndicatorSlot` contract. The map given to `setup` must still
        // be alive and unmodified. Unset slots fall back to the context's map.
        let (cached_lower, cached_middle, cached_upper) = unsafe {
            (
                self.lower_slot.get(),
                self.middle_slot.get(),
                self.upper_slot.get(),
            )
        };
        let (Some(lower_series), Some(middle_series), Some(upper_series)) = (
            cached_lower.or_else(|| ctx.indicators.get(&self.lower_key)),
            cached_middle.or_else(|| ctx.indicators.get(&self.middle_key)),
            cached_upper.or_else(|| ctx.indicators.get(&self.upper_key)),
        ) else {
            return Signal::hold();
        };

        let band_at = |series: &Vec<Option<f64>>| series.get(idx).copied().flatten();
        let (Some(lower), Some(middle), Some(upper)) = (
            band_at(lower_series),
            band_at(middle_series),
            band_at(upper_series),
        ) else {
            return Signal::hold();
        };

        if close <= lower && !ctx.has_position() {
            return Signal::long(candle.timestamp, close)
                .with_reason("Price at lower Bollinger Band");
        }

        if ctx.is_long() {
            let (target, label) = if self.exit_at_middle {
                (middle, "middle")
            } else {
                (upper, "upper")
            };
            if close >= target {
                return Signal::exit(candle.timestamp, close)
                    .with_reason(format!("Price reached {label} Bollinger Band"));
            }
        }

        if close >= upper && !ctx.has_position() {
            return Signal::short(candle.timestamp, close)
                .with_reason("Price at upper Bollinger Band");
        }

        if ctx.is_short() {
            let (target, label) = if self.exit_at_middle {
                (middle, "middle")
            } else {
                (lower, "lower")
            };
            if close <= target {
                return Signal::exit(candle.timestamp, close)
                    .with_reason(format!("Price reached {label} Bollinger Band"));
            }
        }

        Signal::hold()
    }
}
/// Trend-following strategy that flips with the SuperTrend direction.
///
/// The direction series is read from `supertrend_uptrend_{period}_{multiplier}`. The
/// multiplier is formatted with `{}`, so `3.0` appears as `3`.
#[derive(Debug, Clone)]
pub struct SuperTrendFollow {
    /// ATR look-back length.
    pub period: usize,
    /// ATR multiplier for the band offset.
    pub multiplier: f64,
    uptrend_key: String,
    uptrend_slot: IndicatorSlot,
}

impl SuperTrendFollow {
    /// Builds the strategy for the given SuperTrend parameters.
    pub fn new(period: usize, multiplier: f64) -> Self {
        let uptrend_key = format!("supertrend_uptrend_{period}_{multiplier}");
        Self {
            period,
            multiplier,
            uptrend_key,
            uptrend_slot: IndicatorSlot::default(),
        }
    }
}

impl Default for SuperTrendFollow {
    /// Period 10, multiplier 3.
    fn default() -> Self {
        SuperTrendFollow::new(10, 3.0)
    }
}
impl Strategy for SuperTrendFollow {
    fn name(&self) -> &str {
        "SuperTrend Follow"
    }

    /// Requests one SuperTrend indicator registered under the name `supertrend`.
    ///
    /// NOTE(review): `setup`/`on_candle` read `supertrend_uptrend_*` instead; presumably the
    /// engine expands SuperTrend into that key — confirm.
    fn required_indicators(&self) -> Vec<(String, Indicator)> {
        let supertrend = Indicator::Supertrend {
            period: self.period,
            multiplier: self.multiplier,
        };
        vec![(String::from("supertrend"), supertrend)]
    }

    /// Caches a pointer to the uptrend series if `indicators` contains it.
    fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
        let Some(series) = indicators.get(&self.uptrend_key) else {
            return;
        };
        self.uptrend_slot.set(series);
    }

    fn warmup_period(&self) -> usize {
        1 + self.period
    }

    /// Signals when the trend direction flips between the previous and current bar.
    ///
    /// A flip to uptrend closes a short, or goes long when flat. A flip to downtrend closes a
    /// long, or goes short when flat. Values above 0.5 count as uptrend (the series is
    /// presumably a 0/1 flag — confirm).
    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
        let candle = ctx.current_candle();
        let idx = ctx.index;

        // SAFETY: relies on the `IndicatorSlot` contract. The map given to `setup` must still
        // be alive and unmodified. An unset slot falls back to the context's map.
        let cached = unsafe { self.uptrend_slot.get() };
        let Some(flags) = cached.or_else(|| ctx.indicators.get(&self.uptrend_key)) else {
            return Signal::hold();
        };
        let flag_at = |n: usize| flags.get(n).copied().flatten();

        let (Some(now), Some(prev)) = (flag_at(idx), idx.checked_sub(1).and_then(flag_at)) else {
            return Signal::hold();
        };

        match (prev > 0.5, now > 0.5) {
            // Downtrend -> uptrend.
            (false, true) => {
                if ctx.is_short() {
                    return Signal::exit(candle.timestamp, candle.close)
                        .with_reason("SuperTrend turned bullish - close short");
                }
                if !ctx.has_position() {
                    return Signal::long(candle.timestamp, candle.close)
                        .with_reason("SuperTrend turned bullish");
                }
            }
            // Uptrend -> downtrend.
            (true, false) => {
                if ctx.is_long() {
                    return Signal::exit(candle.timestamp, candle.close)
                        .with_reason("SuperTrend turned bearish - close long");
                }
                if !ctx.has_position() {
                    return Signal::short(candle.timestamp, candle.close)
                        .with_reason("SuperTrend turned bearish");
                }
            }
            _ => {}
        }

        Signal::hold()
    }
}
/// Breakout strategy on Donchian channels.
///
/// The channel series are read from `donchian_{upper,middle,lower}_{period}`.
#[derive(Debug, Clone)]
pub struct DonchianBreakout {
    /// Channel look-back length.
    pub period: usize,
    /// Also exit when price returns to the middle channel.
    pub exit_at_middle: bool,
    upper_key: String,
    middle_key: String,
    lower_key: String,
    upper_slot: IndicatorSlot,
    middle_slot: IndicatorSlot,
    lower_slot: IndicatorSlot,
}

impl DonchianBreakout {
    /// Builds the strategy; middle-channel exits are enabled by default.
    pub fn new(period: usize) -> Self {
        let channel_key = |line: &str| format!("donchian_{line}_{period}");
        Self {
            period,
            exit_at_middle: true,
            upper_key: channel_key("upper"),
            middle_key: channel_key("middle"),
            lower_key: channel_key("lower"),
            upper_slot: IndicatorSlot::default(),
            middle_slot: IndicatorSlot::default(),
            lower_slot: IndicatorSlot::default(),
        }
    }

    /// Builder-style toggle for middle-channel exits.
    pub fn exit_at_middle(self, at_middle: bool) -> Self {
        Self {
            exit_at_middle: at_middle,
            ..self
        }
    }
}

impl Default for DonchianBreakout {
    /// 20-period channels.
    fn default() -> Self {
        DonchianBreakout::new(20)
    }
}
impl Strategy for DonchianBreakout {
    fn name(&self) -> &str {
        "Donchian Breakout"
    }

    /// Requests one Donchian channel indicator registered under the name `donchian`.
    ///
    /// NOTE(review): `setup`/`on_candle` read `donchian_{upper,middle,lower}_{period}` instead;
    /// presumably the engine expands the channel into those keys — confirm.
    fn required_indicators(&self) -> Vec<(String, Indicator)> {
        vec![(
            "donchian".to_string(),
            Indicator::DonchianChannels(self.period),
        )]
    }

    /// Caches pointers to whichever channel series are present in `indicators`.
    fn setup(&mut self, indicators: &HashMap<String, Vec<Option<f64>>>) {
        if let Some(v) = indicators.get(&self.upper_key) {
            self.upper_slot.set(v);
        }
        if let Some(v) = indicators.get(&self.middle_key) {
            self.middle_slot.set(v);
        }
        if let Some(v) = indicators.get(&self.lower_key) {
            self.lower_slot.set(v);
        }
    }

    fn warmup_period(&self) -> usize {
        self.period
    }

    /// Evaluates the current bar and returns at most one action. Rules are checked in order:
    ///
    /// 1. Close above the previous bar's upper channel: exit a short, or go long when flat.
    /// 2. Close below the previous bar's lower channel: exit a long, or go short when flat.
    /// 3. With `exit_at_middle`, exit a long at/below the current middle channel, or a short
    ///    at/above it.
    ///
    /// Returns `hold` when the series or any current channel value is missing. Bar 0 has no
    /// previous channel, so only rule 3 can fire there.
    fn on_candle(&self, ctx: &StrategyContext) -> Signal {
        let candle = ctx.current_candle();
        let close = candle.close;
        let i = ctx.index;
        // SAFETY: relies on the `IndicatorSlot` contract. The map given to `setup` must still
        // be alive and unmodified. Unset slots fall back to the context's map.
        let upper_vals =
            unsafe { self.upper_slot.get() }.or_else(|| ctx.indicators.get(&self.upper_key));
        let middle_vals =
            unsafe { self.middle_slot.get() }.or_else(|| ctx.indicators.get(&self.middle_key));
        let lower_vals =
            unsafe { self.lower_slot.get() }.or_else(|| ctx.indicators.get(&self.lower_key));
        let (Some(upper_vals), Some(middle_vals), Some(lower_vals)) =
            (upper_vals, middle_vals, lower_vals)
        else {
            return Signal::hold();
        };
        let get = |vals: &Vec<Option<f64>>, idx: usize| vals.get(idx).and_then(|&v| v);
        // The current upper/lower values serve only as a readiness guard. Breakouts are
        // measured against the previous bar's channel (presumably because the current
        // channel already includes this bar's price — confirm).
        let (Some(_upper_val), Some(middle_val), Some(_lower_val)) =
            (get(upper_vals, i), get(middle_vals, i), get(lower_vals, i))
        else {
            return Signal::hold();
        };
        let prev_upper = if i > 0 { get(upper_vals, i - 1) } else { None };
        let prev_lower = if i > 0 { get(lower_vals, i - 1) } else { None };
        if let Some(prev_up) = prev_upper
            && close > prev_up
        {
            // An upside breakout invalidates a short, mirroring how a downside breakdown
            // closes a long below. Without this, a short could only be closed by the
            // middle-channel rule, i.e. never when `exit_at_middle` is false.
            if ctx.is_short() {
                return Signal::exit(candle.timestamp, close)
                    .with_reason("Donchian upper channel breakout - close short");
            }
            if !ctx.has_position() {
                return Signal::long(candle.timestamp, close)
                    .with_reason("Donchian upper channel breakout");
            }
        }
        if let Some(prev_low) = prev_lower
            && close < prev_low
        {
            if ctx.is_long() {
                return Signal::exit(candle.timestamp, close)
                    .with_reason("Donchian lower channel breakdown - close long");
            }
            if !ctx.has_position() {
                return Signal::short(candle.timestamp, close)
                    .with_reason("Donchian lower channel breakdown");
            }
        }
        if ctx.is_long() && self.exit_at_middle && close <= middle_val {
            return Signal::exit(candle.timestamp, close)
                .with_reason("Price reached Donchian middle channel");
        }
        if ctx.is_short() && self.exit_at_middle && close >= middle_val {
            return Signal::exit(candle.timestamp, close)
                .with_reason("Price reached Donchian middle channel");
        }
        Signal::hold()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loose float comparison for configured thresholds and multipliers.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 0.01
    }

    #[test]
    fn test_sma_crossover_default() {
        let strategy = SmaCrossover::default();
        assert_eq!((strategy.fast_period, strategy.slow_period), (10, 20));
    }

    #[test]
    fn test_sma_crossover_custom() {
        let strategy = SmaCrossover::new(5, 15);
        assert_eq!((strategy.fast_period, strategy.slow_period), (5, 15));
    }

    #[test]
    fn test_rsi_default() {
        let strategy = RsiReversal::default();
        assert_eq!(strategy.period, 14);
        assert!(approx_eq(strategy.oversold, 30.0));
        assert!(approx_eq(strategy.overbought, 70.0));
    }

    #[test]
    fn test_rsi_with_thresholds() {
        let strategy = RsiReversal::new(10).with_thresholds(25.0, 75.0);
        assert_eq!(strategy.period, 10);
        assert!(approx_eq(strategy.oversold, 25.0));
        assert!(approx_eq(strategy.overbought, 75.0));
    }

    #[test]
    fn test_macd_default() {
        let MacdSignal {
            fast, slow, signal, ..
        } = MacdSignal::default();
        assert_eq!((fast, slow, signal), (12, 26, 9));
    }

    #[test]
    fn test_bollinger_default() {
        let strategy = BollingerMeanReversion::default();
        assert_eq!(strategy.period, 20);
        assert!(approx_eq(strategy.std_dev, 2.0));
    }

    #[test]
    fn test_supertrend_default() {
        let strategy = SuperTrendFollow::default();
        assert_eq!(strategy.period, 10);
        assert!(approx_eq(strategy.multiplier, 3.0));
    }

    #[test]
    fn test_donchian_default() {
        let strategy = DonchianBreakout::default();
        assert_eq!(strategy.period, 20);
        assert!(strategy.exit_at_middle);
    }

    #[test]
    fn test_strategy_names() {
        let cases = [
            (SmaCrossover::default().name().to_owned(), "SMA Crossover"),
            (RsiReversal::default().name().to_owned(), "RSI Reversal"),
            (MacdSignal::default().name().to_owned(), "MACD Signal"),
            (
                BollingerMeanReversion::default().name().to_owned(),
                "Bollinger Mean Reversion",
            ),
            (SuperTrendFollow::default().name().to_owned(), "SuperTrend Follow"),
            (DonchianBreakout::default().name().to_owned(), "Donchian Breakout"),
        ];
        for (actual, expected) in cases {
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn test_required_indicators() {
        let keys_of = |pairs: Vec<(String, Indicator)>| -> Vec<String> {
            pairs.into_iter().map(|(key, _)| key).collect()
        };
        assert_eq!(
            keys_of(SmaCrossover::new(5, 10).required_indicators()),
            ["sma_5", "sma_10"]
        );
        assert_eq!(
            keys_of(RsiReversal::new(14).required_indicators()),
            ["rsi_14"]
        );
    }
}