nt_features/
technical.rs

// Technical indicators implementation
//
// Includes: SMA, EMA, RSI, MACD, Bollinger Bands
// Performance target: <1ms per indicator update
5
6use crate::{FeatureError, Result};
7use rust_decimal::Decimal;
8use std::collections::VecDeque;
9
10pub struct TechnicalIndicators {
11    config: IndicatorConfig,
12    price_window: VecDeque<Decimal>,
13}
14
/// Lookback periods and parameters for the indicators computed by `TechnicalIndicators`.
///
/// Every period is counted in price samples, i.e. in calls to `TechnicalIndicators::add_price`.
#[derive(Debug, Clone)]
pub struct IndicatorConfig {
    /// Number of recent prices averaged by the simple moving average.
    pub sma_period: usize,
    /// Period for an exponential moving average.
    pub ema_period: usize,
    /// Number of recent price changes averaged by the relative strength index.
    pub rsi_period: usize,
    /// Period of the fast EMA in the MACD line.
    pub macd_fast: usize,
    /// Period of the slow EMA in the MACD line.
    pub macd_slow: usize,
    /// Period of the MACD signal-line EMA.
    pub macd_signal: usize,
    /// Number of recent prices used for the Bollinger middle band and standard deviation.
    pub bb_period: usize,
    /// Bollinger band width, in standard deviations from the middle band.
    pub bb_std_dev: f64,
}

impl Default for IndicatorConfig {
    /// Commonly used defaults: SMA(20), EMA(12), RSI(14), MACD(12, 26, 9) and
    /// Bollinger Bands(20, 2.0).
    fn default() -> Self {
        let (macd_fast, macd_slow, macd_signal) = (12, 26, 9);
        let (bb_period, bb_std_dev) = (20, 2.0);

        Self {
            macd_fast,
            macd_slow,
            macd_signal,
            bb_period,
            bb_std_dev,
            sma_period: 20,
            ema_period: 12,
            rsi_period: 14,
        }
    }
}
41
42impl TechnicalIndicators {
43    pub fn new(config: IndicatorConfig) -> Self {
44        let max_window = config
45            .sma_period
46            .max(config.ema_period)
47            .max(config.rsi_period)
48            .max(config.macd_slow)
49            .max(config.bb_period);
50
51        Self {
52            config,
53            price_window: VecDeque::with_capacity(max_window),
54        }
55    }
56
57    pub fn add_price(&mut self, price: Decimal) {
58        let max_window = self.config.sma_period.max(self.config.macd_slow);
59
60        self.price_window.push_back(price);
61        if self.price_window.len() > max_window {
62            self.price_window.pop_front();
63        }
64    }
65
66    /// Simple Moving Average
67    pub fn sma(&self) -> Result<Decimal> {
68        if self.price_window.len() < self.config.sma_period {
69            return Err(FeatureError::InsufficientData(self.config.sma_period));
70        }
71
72        let sum: Decimal = self
73            .price_window
74            .iter()
75            .rev()
76            .take(self.config.sma_period)
77            .sum();
78
79        Ok(sum / Decimal::from(self.config.sma_period))
80    }
81
82    /// Exponential Moving Average
83    pub fn ema(&self, prices: &[Decimal], period: usize) -> Result<Decimal> {
84        if prices.len() < period {
85            return Err(FeatureError::InsufficientData(period));
86        }
87
88        let multiplier = Decimal::from(2) / Decimal::from(period + 1);
89        let sma: Decimal = prices.iter().take(period).sum::<Decimal>() / Decimal::from(period);
90
91        let mut ema = sma;
92        for &price in prices.iter().skip(period) {
93            ema = (price - ema) * multiplier + ema;
94        }
95
96        Ok(ema)
97    }
98
99    /// Relative Strength Index
100    pub fn rsi(&self) -> Result<Decimal> {
101        if self.price_window.len() < self.config.rsi_period + 1 {
102            return Err(FeatureError::InsufficientData(self.config.rsi_period + 1));
103        }
104
105        let mut gains = Vec::new();
106        let mut losses = Vec::new();
107
108        for window in self.price_window.iter().collect::<Vec<_>>().windows(2) {
109            let change = window[1] - window[0];
110            if change > Decimal::ZERO {
111                gains.push(change);
112                losses.push(Decimal::ZERO);
113            } else {
114                gains.push(Decimal::ZERO);
115                losses.push(-change);
116            }
117        }
118
119        let avg_gain: Decimal = gains
120            .iter()
121            .rev()
122            .take(self.config.rsi_period)
123            .sum::<Decimal>()
124            / Decimal::from(self.config.rsi_period);
125
126        let avg_loss: Decimal = losses
127            .iter()
128            .rev()
129            .take(self.config.rsi_period)
130            .sum::<Decimal>()
131            / Decimal::from(self.config.rsi_period);
132
133        if avg_loss == Decimal::ZERO {
134            return Ok(Decimal::from(100));
135        }
136
137        let rs = avg_gain / avg_loss;
138        let rsi = Decimal::from(100) - (Decimal::from(100) / (Decimal::ONE + rs));
139
140        Ok(rsi)
141    }
142
143    /// MACD (Moving Average Convergence Divergence)
144    pub fn macd(&self) -> Result<(Decimal, Decimal, Decimal)> {
145        if self.price_window.len() < self.config.macd_slow {
146            return Err(FeatureError::InsufficientData(self.config.macd_slow));
147        }
148
149        let prices: Vec<Decimal> = self.price_window.iter().copied().collect();
150
151        let fast_ema = self.ema(&prices, self.config.macd_fast)?;
152        let slow_ema = self.ema(&prices, self.config.macd_slow)?;
153        let macd_line = fast_ema - slow_ema;
154
155        // Signal line (9-period EMA of MACD)
156        // Simplified: return zero for signal and histogram
157        let signal_line = Decimal::ZERO;
158        let histogram = macd_line - signal_line;
159
160        Ok((macd_line, signal_line, histogram))
161    }
162
163    /// Bollinger Bands
164    pub fn bollinger_bands(&self) -> Result<(Decimal, Decimal, Decimal)> {
165        if self.price_window.len() < self.config.bb_period {
166            return Err(FeatureError::InsufficientData(self.config.bb_period));
167        }
168
169        let sma = self.sma()?;
170
171        let recent_prices: Vec<Decimal> = self
172            .price_window
173            .iter()
174            .rev()
175            .take(self.config.bb_period)
176            .copied()
177            .collect();
178
179        // Calculate standard deviation
180        let variance: Decimal = recent_prices
181            .iter()
182            .map(|&price| {
183                let diff = price - sma;
184                diff * diff
185            })
186            .sum::<Decimal>()
187            / Decimal::from(self.config.bb_period);
188
189        // Calculate standard deviation (convert to f64 for sqrt, then back to Decimal)
190        let variance_f64 = variance
191            .to_string()
192            .parse::<f64>()
193            .map_err(|_| FeatureError::Calculation("Cannot convert variance to f64".to_string()))?;
194        let std_dev_f64 = variance_f64.sqrt();
195        let std_dev = Decimal::try_from(std_dev_f64).map_err(|_| {
196            FeatureError::Calculation("Cannot convert sqrt result to Decimal".to_string())
197        })?;
198
199        let multiplier = Decimal::try_from(self.config.bb_std_dev)
200            .map_err(|_| FeatureError::Calculation("Invalid std_dev".to_string()))?;
201
202        let upper_band = sma + (std_dev * multiplier);
203        let lower_band = sma - (std_dev * multiplier);
204
205        Ok((upper_band, sma, lower_band))
206    }
207
208    /// Get all indicators at once
209    pub fn calculate_all(&self) -> Result<IndicatorValues> {
210        Ok(IndicatorValues {
211            sma: self.sma().ok(),
212            rsi: self.rsi().ok(),
213            macd: self.macd().ok(),
214            bollinger: self.bollinger_bands().ok(),
215        })
216    }
217}
218
219#[derive(Debug, Clone)]
220pub struct IndicatorValues {
221    pub sma: Option<Decimal>,
222    pub rsi: Option<Decimal>,
223    pub macd: Option<(Decimal, Decimal, Decimal)>,
224    pub bollinger: Option<(Decimal, Decimal, Decimal)>,
225}
226
227#[cfg(test)]
228mod tests {
229    use super::*;
230    use rust_decimal_macros::dec;
231
232    #[test]
233    fn test_sma() {
234        let _config = IndicatorConfig {
235            sma_period: 3,
236            ..Default::default()
237        };
238
239        let mut indicators = TechnicalIndicators::new(config);
240
241        indicators.add_price(dec!(100));
242        indicators.add_price(dec!(110));
243        indicators.add_price(dec!(120));
244
245        let sma = indicators.sma().unwrap();
246        assert_eq!(sma, dec!(110)); // (100 + 110 + 120) / 3
247    }
248
249    #[test]
250    fn test_rsi_all_gains() {
251        let _config = IndicatorConfig {
252            rsi_period: 3,
253            ..Default::default()
254        };
255
256        let mut indicators = TechnicalIndicators::new(config);
257
258        indicators.add_price(dec!(100));
259        indicators.add_price(dec!(105));
260        indicators.add_price(dec!(110));
261        indicators.add_price(dec!(115));
262
263        let rsi = indicators.rsi().unwrap();
264        assert_eq!(rsi, dec!(100)); // All gains, RSI = 100
265    }
266
267    #[test]
268    fn test_insufficient_data() {
269        let _config = IndicatorConfig {
270            sma_period: 5,
271            ..Default::default()
272        };
273
274        let mut indicators = TechnicalIndicators::new(config);
275
276        indicators.add_price(dec!(100));
277        indicators.add_price(dec!(110));
278
279        let result = indicators.sma();
280        assert!(result.is_err());
281    }
282
283    #[test]
284    fn test_bollinger_bands() {
285        let _config = IndicatorConfig {
286            bb_period: 3,
287            bb_std_dev: 2.0,
288            sma_period: 3,
289            ..Default::default()
290        };
291
292        let mut indicators = TechnicalIndicators::new(config);
293
294        indicators.add_price(dec!(100));
295        indicators.add_price(dec!(110));
296        indicators.add_price(dec!(120));
297
298        let (upper, middle, lower) = indicators.bollinger_bands().unwrap();
299
300        assert_eq!(middle, dec!(110)); // SMA
301        assert!(upper > middle);
302        assert!(lower < middle);
303    }
304}