1use crate::{FeatureError, Result};
7use rust_decimal::Decimal;
8use std::collections::VecDeque;
9
10pub struct TechnicalIndicators {
11 config: IndicatorConfig,
12 price_window: VecDeque<Decimal>,
13}
14
/// Periods and parameters for the indicators computed by `TechnicalIndicators`.
///
/// Every `*_period` field is a count of prices (or of price changes for RSI).
#[derive(Clone, Debug)]
pub struct IndicatorConfig {
    /// Number of most recent prices averaged by the simple moving average.
    pub sma_period: usize,
    /// Period for the exponential moving average.
    pub ema_period: usize,
    /// Number of price changes considered by the relative strength index.
    pub rsi_period: usize,
    /// Period of the fast EMA in MACD.
    pub macd_fast: usize,
    /// Period of the slow EMA in MACD.
    pub macd_slow: usize,
    /// Signal-line period for MACD.
    pub macd_signal: usize,
    /// Number of most recent prices used for the Bollinger bands.
    pub bb_period: usize,
    /// Number of standard deviations between the middle band and each outer band.
    pub bb_std_dev: f64,
}

impl Default for IndicatorConfig {
    /// SMA 20, EMA 12, RSI 14, MACD 12/26/9, Bollinger 20 periods at 2 standard deviations.
    fn default() -> Self {
        IndicatorConfig {
            // Moving averages.
            sma_period: 20,
            ema_period: 12,
            // Momentum.
            rsi_period: 14,
            macd_fast: 12,
            macd_slow: 26,
            macd_signal: 9,
            // Volatility bands.
            bb_period: 20,
            bb_std_dev: 2.0,
        }
    }
}
41
42impl TechnicalIndicators {
43 pub fn new(config: IndicatorConfig) -> Self {
44 let max_window = config
45 .sma_period
46 .max(config.ema_period)
47 .max(config.rsi_period)
48 .max(config.macd_slow)
49 .max(config.bb_period);
50
51 Self {
52 config,
53 price_window: VecDeque::with_capacity(max_window),
54 }
55 }
56
57 pub fn add_price(&mut self, price: Decimal) {
58 let max_window = self.config.sma_period.max(self.config.macd_slow);
59
60 self.price_window.push_back(price);
61 if self.price_window.len() > max_window {
62 self.price_window.pop_front();
63 }
64 }
65
66 pub fn sma(&self) -> Result<Decimal> {
68 if self.price_window.len() < self.config.sma_period {
69 return Err(FeatureError::InsufficientData(self.config.sma_period));
70 }
71
72 let sum: Decimal = self
73 .price_window
74 .iter()
75 .rev()
76 .take(self.config.sma_period)
77 .sum();
78
79 Ok(sum / Decimal::from(self.config.sma_period))
80 }
81
82 pub fn ema(&self, prices: &[Decimal], period: usize) -> Result<Decimal> {
84 if prices.len() < period {
85 return Err(FeatureError::InsufficientData(period));
86 }
87
88 let multiplier = Decimal::from(2) / Decimal::from(period + 1);
89 let sma: Decimal = prices.iter().take(period).sum::<Decimal>() / Decimal::from(period);
90
91 let mut ema = sma;
92 for &price in prices.iter().skip(period) {
93 ema = (price - ema) * multiplier + ema;
94 }
95
96 Ok(ema)
97 }
98
99 pub fn rsi(&self) -> Result<Decimal> {
101 if self.price_window.len() < self.config.rsi_period + 1 {
102 return Err(FeatureError::InsufficientData(self.config.rsi_period + 1));
103 }
104
105 let mut gains = Vec::new();
106 let mut losses = Vec::new();
107
108 for window in self.price_window.iter().collect::<Vec<_>>().windows(2) {
109 let change = window[1] - window[0];
110 if change > Decimal::ZERO {
111 gains.push(change);
112 losses.push(Decimal::ZERO);
113 } else {
114 gains.push(Decimal::ZERO);
115 losses.push(-change);
116 }
117 }
118
119 let avg_gain: Decimal = gains
120 .iter()
121 .rev()
122 .take(self.config.rsi_period)
123 .sum::<Decimal>()
124 / Decimal::from(self.config.rsi_period);
125
126 let avg_loss: Decimal = losses
127 .iter()
128 .rev()
129 .take(self.config.rsi_period)
130 .sum::<Decimal>()
131 / Decimal::from(self.config.rsi_period);
132
133 if avg_loss == Decimal::ZERO {
134 return Ok(Decimal::from(100));
135 }
136
137 let rs = avg_gain / avg_loss;
138 let rsi = Decimal::from(100) - (Decimal::from(100) / (Decimal::ONE + rs));
139
140 Ok(rsi)
141 }
142
143 pub fn macd(&self) -> Result<(Decimal, Decimal, Decimal)> {
145 if self.price_window.len() < self.config.macd_slow {
146 return Err(FeatureError::InsufficientData(self.config.macd_slow));
147 }
148
149 let prices: Vec<Decimal> = self.price_window.iter().copied().collect();
150
151 let fast_ema = self.ema(&prices, self.config.macd_fast)?;
152 let slow_ema = self.ema(&prices, self.config.macd_slow)?;
153 let macd_line = fast_ema - slow_ema;
154
155 let signal_line = Decimal::ZERO;
158 let histogram = macd_line - signal_line;
159
160 Ok((macd_line, signal_line, histogram))
161 }
162
163 pub fn bollinger_bands(&self) -> Result<(Decimal, Decimal, Decimal)> {
165 if self.price_window.len() < self.config.bb_period {
166 return Err(FeatureError::InsufficientData(self.config.bb_period));
167 }
168
169 let sma = self.sma()?;
170
171 let recent_prices: Vec<Decimal> = self
172 .price_window
173 .iter()
174 .rev()
175 .take(self.config.bb_period)
176 .copied()
177 .collect();
178
179 let variance: Decimal = recent_prices
181 .iter()
182 .map(|&price| {
183 let diff = price - sma;
184 diff * diff
185 })
186 .sum::<Decimal>()
187 / Decimal::from(self.config.bb_period);
188
189 let variance_f64 = variance
191 .to_string()
192 .parse::<f64>()
193 .map_err(|_| FeatureError::Calculation("Cannot convert variance to f64".to_string()))?;
194 let std_dev_f64 = variance_f64.sqrt();
195 let std_dev = Decimal::try_from(std_dev_f64).map_err(|_| {
196 FeatureError::Calculation("Cannot convert sqrt result to Decimal".to_string())
197 })?;
198
199 let multiplier = Decimal::try_from(self.config.bb_std_dev)
200 .map_err(|_| FeatureError::Calculation("Invalid std_dev".to_string()))?;
201
202 let upper_band = sma + (std_dev * multiplier);
203 let lower_band = sma - (std_dev * multiplier);
204
205 Ok((upper_band, sma, lower_band))
206 }
207
208 pub fn calculate_all(&self) -> Result<IndicatorValues> {
210 Ok(IndicatorValues {
211 sma: self.sma().ok(),
212 rsi: self.rsi().ok(),
213 macd: self.macd().ok(),
214 bollinger: self.bollinger_bands().ok(),
215 })
216 }
217}
218
219#[derive(Debug, Clone)]
220pub struct IndicatorValues {
221 pub sma: Option<Decimal>,
222 pub rsi: Option<Decimal>,
223 pub macd: Option<(Decimal, Decimal, Decimal)>,
224 pub bollinger: Option<(Decimal, Decimal, Decimal)>,
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// Builds an indicator set and feeds it `prices` in order.
    fn with_prices(config: IndicatorConfig, prices: &[Decimal]) -> TechnicalIndicators {
        let mut indicators = TechnicalIndicators::new(config);
        for &price in prices {
            indicators.add_price(price);
        }
        indicators
    }

    #[test]
    fn test_sma() {
        let config = IndicatorConfig {
            sma_period: 3,
            ..Default::default()
        };
        let indicators = with_prices(config, &[dec!(100), dec!(110), dec!(120)]);

        let sma = indicators.sma().unwrap();
        assert_eq!(sma, dec!(110));
    }

    #[test]
    fn test_rsi_all_gains() {
        let config = IndicatorConfig {
            rsi_period: 3,
            ..Default::default()
        };
        let indicators =
            with_prices(config, &[dec!(100), dec!(105), dec!(110), dec!(115)]);

        let rsi = indicators.rsi().unwrap();
        assert_eq!(rsi, dec!(100));
    }

    #[test]
    fn test_insufficient_data() {
        let config = IndicatorConfig {
            sma_period: 5,
            ..Default::default()
        };
        let indicators = with_prices(config, &[dec!(100), dec!(110)]);

        let result = indicators.sma();
        assert!(result.is_err());
    }

    #[test]
    fn test_bollinger_bands() {
        let config = IndicatorConfig {
            bb_period: 3,
            bb_std_dev: 2.0,
            sma_period: 3,
            ..Default::default()
        };
        let indicators = with_prices(config, &[dec!(100), dec!(110), dec!(120)]);

        let (upper, middle, lower) = indicators.bollinger_bands().unwrap();

        assert_eq!(middle, dec!(110));
        assert!(upper > middle);
        assert!(lower < middle);
    }

    // Regression: the middle band used `sma_period` instead of `bb_period`.
    #[test]
    fn test_bollinger_uses_bb_period_not_sma_period() {
        let config = IndicatorConfig {
            bb_period: 3,
            sma_period: 5,
            ..Default::default()
        };
        let indicators = with_prices(config, &[dec!(100), dec!(110), dec!(120)]);

        let (upper, middle, lower) = indicators.bollinger_bands().unwrap();
        assert_eq!(middle, dec!(110));
        assert_eq!(upper - middle, middle - lower);
        assert!(indicators.sma().is_err());
    }

    // Regression: the window was capped at max(sma_period, macd_slow).
    #[test]
    fn test_window_retains_bb_period() {
        let config = IndicatorConfig {
            bb_period: 40,
            ..Default::default()
        };
        let prices: Vec<Decimal> = (0..40).map(|i| Decimal::from(100 + i)).collect();
        let indicators = with_prices(config, &prices);

        assert!(indicators.bollinger_bands().is_ok());
    }

    // Regression: the signal line was hard-coded to zero.
    #[test]
    fn test_macd_signal_line() {
        let config = IndicatorConfig {
            macd_fast: 1,
            macd_slow: 3,
            macd_signal: 3,
            ..Default::default()
        };
        let mut indicators =
            with_prices(config, &[dec!(1), dec!(2), dec!(3), dec!(4)]);
        assert!(indicators.macd().is_err());

        indicators.add_price(dec!(5));
        let (macd, signal, histogram) = indicators.macd().unwrap();
        assert_eq!(macd, dec!(1));
        assert_eq!(signal, dec!(1));
        assert_eq!(histogram, dec!(0));
    }

    #[test]
    fn test_zero_period_is_error_not_panic() {
        let config = IndicatorConfig {
            sma_period: 0,
            ..Default::default()
        };
        let indicators = with_prices(config, &[dec!(100)]);

        assert!(indicators.sma().is_err());
    }
}