use super::portfolio::Portfolio;
use super::trade::Trade;
use crate::model::Bar;
use chrono::{DateTime, Utc};
/// The action a [`Signal`] asks for.
///
/// Each variant has a matching constructor on [`Signal`]
/// (`buy`, `sell`, `short`, `cover`, `exit`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SignalType {
/// Set by [`Signal::buy`]. `SmaCrossover` emits it on a golden cross when flat or short.
Buy,
/// Set by [`Signal::sell`]. `SmaCrossover` emits it on a death cross while long.
Sell,
/// Set by [`Signal::short`]. Presumably opens a short position — confirm against the engine.
Short,
/// Set by [`Signal::cover`]. `SmaCrossover` emits it on a golden cross while short.
Cover,
/// Set by [`Signal::exit`]. Presumably closes whatever position is open — confirm against the engine.
Exit,
}
/// A trading instruction produced by a strategy for one symbol.
///
/// Build one with a constructor such as [`Signal::buy`] and refine it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct Signal {
/// What kind of action is requested.
pub signal_type: SignalType,
/// Symbol the signal applies to.
pub symbol: String,
/// Requested quantity; `None` unless set via `with_quantity`.
/// NOTE(review): how the consumer sizes a `None` quantity is not visible here.
pub quantity: Option<f64>,
/// Requested price; `None` unless set via `with_price`.
/// NOTE(review): presumably `None` means "fill at market" — confirm against the engine.
pub price: Option<f64>,
/// Timestamp of the signal. The constructors fill this with `Utc::now()`;
/// use `with_timestamp` to override it.
pub ts: DateTime<Utc>,
/// Optional free-text explanation of why the signal was emitted.
pub reason: Option<String>,
}
impl Signal {
pub fn buy(symbol: impl Into<String>) -> Self {
Self {
signal_type: SignalType::Buy,
symbol: symbol.into(),
quantity: None,
price: None,
ts: Utc::now(),
reason: None,
}
}
pub fn sell(symbol: impl Into<String>) -> Self {
Self {
signal_type: SignalType::Sell,
symbol: symbol.into(),
quantity: None,
price: None,
ts: Utc::now(),
reason: None,
}
}
pub fn short(symbol: impl Into<String>) -> Self {
Self {
signal_type: SignalType::Short,
symbol: symbol.into(),
quantity: None,
price: None,
ts: Utc::now(),
reason: None,
}
}
pub fn cover(symbol: impl Into<String>) -> Self {
Self {
signal_type: SignalType::Cover,
symbol: symbol.into(),
quantity: None,
price: None,
ts: Utc::now(),
reason: None,
}
}
pub fn exit(symbol: impl Into<String>) -> Self {
Self {
signal_type: SignalType::Exit,
symbol: symbol.into(),
quantity: None,
price: None,
ts: Utc::now(),
reason: None,
}
}
pub fn with_quantity(mut self, qty: f64) -> Self {
self.quantity = Some(qty);
self
}
pub fn with_price(mut self, price: f64) -> Self {
self.price = Some(price);
self
}
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
self.ts = ts;
self
}
pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
self.reason = Some(reason.into());
self
}
}
/// Read-only view handed to a strategy: the bar series, the position of the
/// current bar within it, the portfolio, and the symbol being traded.
#[derive(Debug)]
pub struct StrategyContext<'a> {
/// Index of the current bar within `bars`; all offsets and look-back
/// windows in the accessor methods are measured from here.
pub bar_idx: usize,
/// The current bar.
/// NOTE(review): presumably always `bars[bar_idx]` — nothing here enforces it.
pub bar: &'a Bar,
/// The bar series that `bar_idx` indexes into.
pub bars: &'a [Bar],
/// Portfolio queried by the position/equity helpers.
pub portfolio: &'a Portfolio,
/// Symbol used for every portfolio position lookup.
pub symbol: &'a str,
}
impl<'a> StrategyContext<'a> {
    /// Returns the `period` bars ending at (and including) the current bar.
    ///
    /// Yields `None` when `period` is zero, when fewer than `period` bars are
    /// available up to `bar_idx`, or when `bar_idx` lies outside `bars`.
    fn trailing_window(&self, period: usize) -> Option<&'a [Bar]> {
        if period == 0 {
            return None;
        }
        let end = self.bar_idx.checked_add(1)?;
        let start = end.checked_sub(period)?;
        self.bars.get(start..end)
    }

    /// Returns the bar `offset` positions away from the current one.
    ///
    /// `0` is the current bar, negative offsets look back, positive offsets
    /// look ahead. Returns `None` when the resulting index falls outside
    /// `bars`. The index arithmetic is checked, so extreme offsets cannot wrap.
    pub fn bar_at(&self, offset: isize) -> Option<&'a Bar> {
        let distance = offset.unsigned_abs();
        let idx = if offset < 0 {
            self.bar_idx.checked_sub(distance)?
        } else {
            self.bar_idx.checked_add(distance)?
        };
        self.bars.get(idx)
    }

    /// Close of the bar at `offset` (see [`StrategyContext::bar_at`]).
    pub fn close(&self, offset: isize) -> Option<f64> {
        self.bar_at(offset).map(|b| b.close)
    }

    /// High of the bar at `offset` (see [`StrategyContext::bar_at`]).
    pub fn high(&self, offset: isize) -> Option<f64> {
        self.bar_at(offset).map(|b| b.high)
    }

    /// Low of the bar at `offset` (see [`StrategyContext::bar_at`]).
    pub fn low(&self, offset: isize) -> Option<f64> {
        self.bar_at(offset).map(|b| b.low)
    }

    /// Open of the bar at `offset` (see [`StrategyContext::bar_at`]).
    pub fn open(&self, offset: isize) -> Option<f64> {
        self.bar_at(offset).map(|b| b.open)
    }

    /// Volume of the bar at `offset` (see [`StrategyContext::bar_at`]).
    pub fn volume(&self, offset: isize) -> Option<f64> {
        self.bar_at(offset).map(|b| b.volume)
    }

    /// Simple moving average of the last `period` closes, current bar included.
    ///
    /// Returns `None` when there are not yet `period` bars, when `bar_idx` is
    /// out of range, or when `period` is zero (previously this produced
    /// `Some(NaN)` from a 0/0 division).
    pub fn sma(&self, period: usize) -> Option<f64> {
        let window = self.trailing_window(period)?;
        let sum: f64 = window.iter().map(|b| b.close).sum();
        Some(sum / period as f64)
    }

    /// Highest `high` over the last `period` bars, current bar included.
    ///
    /// Returns `None` under the same conditions as [`StrategyContext::sma`].
    pub fn highest_high(&self, period: usize) -> Option<f64> {
        self.trailing_window(period)?
            .iter()
            .map(|b| b.high)
            .reduce(f64::max)
    }

    /// Lowest `low` over the last `period` bars, current bar included.
    ///
    /// Returns `None` under the same conditions as [`StrategyContext::sma`].
    pub fn lowest_low(&self, period: usize) -> Option<f64> {
        self.trailing_window(period)?
            .iter()
            .map(|b| b.low)
            .reduce(f64::min)
    }

    /// `true` when the portfolio holds a long position in `symbol`.
    pub fn is_long(&self) -> bool {
        self.portfolio
            .get_pos(self.symbol)
            .map_or(false, |p| matches!(p.side, super::portfolio::PosSide::Long))
    }

    /// `true` when the portfolio holds a short position in `symbol`.
    pub fn is_short(&self) -> bool {
        self.portfolio
            .get_pos(self.symbol)
            .map_or(false, |p| matches!(p.side, super::portfolio::PosSide::Short))
    }

    /// `true` when the portfolio reports no position in `symbol`.
    pub fn is_flat(&self) -> bool {
        !self.portfolio.has_pos(self.symbol)
    }

    /// Quantity of the position in `symbol`, or `0.0` when there is none.
    pub fn pos_size(&self) -> f64 {
        self.portfolio
            .get_pos(self.symbol)
            .map_or(0.0, |p| p.quantity)
    }

    /// Current portfolio equity, as reported by the portfolio.
    pub fn equity(&self) -> f64 {
        self.portfolio.equity()
    }

    /// Unrealized PnL of the position in `symbol`, or `0.0` when there is none.
    pub fn unrealized_pnl(&self) -> f64 {
        self.portfolio
            .get_pos(self.symbol)
            .map_or(0.0, |p| p.unrealized_pnl)
    }
}
/// Interface a trading strategy implements.
///
/// Only `on_bar` and `name` are required; the other hooks have defaults.
/// NOTE(review): the call order (init once, then on_bar per bar, on_fill per
/// executed trade) is presumed from the names — the driver is not in this file.
pub trait Strategy: Send {
/// Hook that receives the bar series; the default does nothing.
fn init(&mut self, _data: &[Bar]) {}
/// Inspects the context for the current bar and returns the signals to act
/// on; an empty vector means "no action".
fn on_bar(&mut self, ctx: &StrategyContext) -> Vec<Signal>;
/// Hook that receives a trade; the default does nothing.
fn on_fill(&mut self, _trade: &Trade) {}
/// Display name of the strategy.
fn name(&self) -> &str;
/// Strategy parameters as `(name, value)` string pairs; the default is empty.
fn params(&self) -> Vec<(String, String)> {
Vec::new()
}
}
/// Moving-average crossover strategy built on two simple moving averages of
/// the close price.
pub struct SmaCrossover {
// Look-back length of the fast average.
fast_period: usize,
// Look-back length of the slow average.
slow_period: usize,
// Fast SMA per bar, filled by `init`; `NaN` where fewer than `fast_period` bars exist.
fast_sma: Vec<f64>,
// Slow SMA per bar, filled by `init`; `NaN` where fewer than `slow_period` bars exist.
slow_sma: Vec<f64>,
}
impl SmaCrossover {
    /// Creates a crossover strategy with the given look-back lengths.
    ///
    /// The periods are stored as given (no validation); the SMA series stay
    /// empty until `init` is called.
    pub fn new(fast_period: usize, slow_period: usize) -> Self {
        let (fast_sma, slow_sma) = (Vec::new(), Vec::new());
        Self {
            fast_period,
            slow_period,
            fast_sma,
            slow_sma,
        }
    }

    /// Computes the simple moving average of `close` for every bar in `data`.
    ///
    /// The result has one entry per bar. Entries for which fewer than
    /// `period` bars are available are `NaN`; a `period` of zero yields `NaN`
    /// everywhere.
    fn calculate_sma(data: &[Bar], period: usize) -> Vec<f64> {
        let total = data.len();
        if period == 0 {
            // An empty window averages to 0/0, i.e. NaN, for every bar.
            return vec![f64::NAN; total];
        }

        // Bars before the first full window have no average yet.
        let warmup = (period - 1).min(total);
        let mut series = Vec::with_capacity(total);
        series.resize(warmup, f64::NAN);
        series.extend(
            data.windows(period)
                .map(|window| window.iter().map(|bar| bar.close).sum::<f64>() / period as f64),
        );
        series
    }
}
impl Strategy for SmaCrossover {
    /// Display name of the strategy.
    fn name(&self) -> &str {
        "SMA Crossover"
    }

    /// Reports both look-back lengths as `(name, value)` pairs.
    fn params(&self) -> Vec<(String, String)> {
        [
            ("fast_period", self.fast_period),
            ("slow_period", self.slow_period),
        ]
        .iter()
        .map(|(key, value)| (key.to_string(), value.to_string()))
        .collect()
    }

    /// Precomputes the fast and slow SMA series for the whole data set.
    fn init(&mut self, data: &[Bar]) {
        let fast = Self::calculate_sma(data, self.fast_period);
        let slow = Self::calculate_sma(data, self.slow_period);
        self.fast_sma = fast;
        self.slow_sma = slow;
    }

    /// Emits crossover signals for the current bar.
    ///
    /// * Fast SMA crossing above the slow SMA: `Cover` if short, then `Buy`
    ///   if flat or short (reason "Golden cross").
    /// * Fast SMA crossing below the slow SMA while long: `Sell`
    ///   (reason "Death cross").
    ///
    /// Nothing is emitted on the first bar, or while either current average
    /// is unavailable. A missing previous average never counts as a cross
    /// because comparisons against `NaN` are false.
    fn on_bar(&mut self, ctx: &StrategyContext) -> Vec<Signal> {
        // Missing entries (e.g. `init` not called) read as NaN.
        fn value_at(series: &[f64], i: usize) -> f64 {
            series.get(i).copied().unwrap_or(f64::NAN)
        }

        let now = ctx.bar_idx;
        if now == 0 {
            return Vec::new();
        }
        let before = now - 1;

        let fast_now = value_at(&self.fast_sma, now);
        let fast_before = value_at(&self.fast_sma, before);
        let slow_now = value_at(&self.slow_sma, now);
        let slow_before = value_at(&self.slow_sma, before);
        if fast_now.is_nan() || slow_now.is_nan() {
            return Vec::new();
        }

        let crossed_up = fast_before <= slow_before && fast_now > slow_now;
        let crossed_down = fast_before >= slow_before && fast_now < slow_now;

        let mut signals = Vec::new();
        if crossed_up {
            if ctx.is_short() {
                signals.push(Signal::cover(ctx.symbol).with_reason("Golden cross"));
            }
            if ctx.is_flat() || ctx.is_short() {
                signals.push(Signal::buy(ctx.symbol).with_reason("Golden cross"));
            }
        }
        if crossed_down && ctx.is_long() {
            signals.push(Signal::sell(ctx.symbol).with_reason("Death cross"));
        }
        signals
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds one bar per price with open = close = price, high = price + 1,
    /// low = price - 1 and a fixed volume.
    fn create_test_bars(prices: Vec<f64>) -> Vec<Bar> {
        prices
            .into_iter()
            .map(|p| Bar {
                time: Utc::now(),
                open: p,
                high: p + 1.0,
                low: p - 1.0,
                close: p,
                volume: 1000.0,
            })
            .collect()
    }

    #[test]
    fn test_signal_creation() {
        let signal = Signal::buy("AAPL").with_quantity(100.0).with_reason("Test");
        assert_eq!(signal.signal_type, SignalType::Buy);
        assert_eq!(signal.symbol, "AAPL");
        assert_eq!(signal.quantity, Some(100.0));
        assert_eq!(signal.price, None);
        assert_eq!(signal.reason, Some("Test".to_string()));
    }

    #[test]
    fn test_sma_crossover_strategy() {
        let mut strategy = SmaCrossover::new(5, 10);
        let bars = create_test_bars(vec![
            100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0,
        ]);
        strategy.init(&bars);
        assert_eq!(strategy.fast_sma.len(), bars.len());
        assert_eq!(strategy.slow_sma.len(), bars.len());

        // Warm-up entries are NaN until a full window is available.
        assert!(strategy.fast_sma[..4].iter().all(|v| v.is_nan()));
        assert!(strategy.slow_sma[..9].iter().all(|v| v.is_nan()));

        // First and last full windows of each series.
        assert!((strategy.fast_sma[4] - 102.0).abs() < 1e-9);
        assert!((strategy.fast_sma[11] - 109.0).abs() < 1e-9);
        assert!((strategy.slow_sma[9] - 104.5).abs() < 1e-9);
        assert!((strategy.slow_sma[11] - 106.5).abs() < 1e-9);
    }

    #[test]
    fn test_ctx_sma() {
        let bars = create_test_bars(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let portfolio = super::super::portfolio::Portfolio::new(100_000.0);
        let ctx = StrategyContext {
            bar_idx: 4,
            bar: &bars[4],
            bars: &bars,
            portfolio: &portfolio,
            symbol: "TEST",
        };
        let sma = ctx.sma(5).unwrap();
        assert!((sma - 30.0).abs() < 0.01);

        // Not enough history for a six-bar window.
        assert!(ctx.sma(6).is_none());

        // Offsets are relative to the current bar and bounds-checked.
        assert_eq!(ctx.close(0), Some(50.0));
        assert_eq!(ctx.close(-4), Some(10.0));
        assert_eq!(ctx.close(-5), None);
        assert_eq!(ctx.close(1), None);

        // Look-back extremes over the last three bars (30, 40, 50).
        assert_eq!(ctx.highest_high(3), Some(51.0));
        assert_eq!(ctx.lowest_low(3), Some(29.0));
    }
}