1use std::collections::HashMap;
19
20use crate::error::IndicatorError;
21use crate::functions::{self};
22use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
23use crate::registry::{param_f64, param_str, param_usize};
24use crate::types::Candle;
25
26#[derive(Debug, Clone)]
29pub struct EmaParams {
30 pub period: usize,
32 pub alpha: Option<f64>,
35 pub column: PriceColumn,
37}
38
39impl Default for EmaParams {
40 fn default() -> Self {
41 Self {
42 period: 20,
43 alpha: None,
44 column: PriceColumn::Close,
45 }
46 }
47}
48
49impl EmaParams {
50 fn effective_alpha(&self) -> f64 {
51 self.alpha
52 .unwrap_or_else(|| 2.0 / (self.period as f64 + 1.0))
53 }
54}
55
56#[derive(Debug, Clone)]
59pub struct Ema {
60 pub params: EmaParams,
61}
62
63impl Ema {
64 pub fn new(params: EmaParams) -> Self {
65 Self { params }
66 }
67
68 pub fn with_period(period: usize) -> Self {
69 Self::new(EmaParams {
70 period,
71 ..Default::default()
72 })
73 }
74
75 fn output_key(&self) -> String {
76 format!("EMA_{}", self.params.period)
77 }
78}
79
80impl Indicator for Ema {
81 fn name(&self) -> &'static str {
82 "EMA"
83 }
84
85 fn required_len(&self) -> usize {
86 self.params.period
87 }
88
89 fn required_columns(&self) -> &[&'static str] {
90 &["close"]
91 }
92
93 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
99 self.check_len(candles)?;
100
101 let prices = self.params.column.extract(candles);
102 let alpha = self.params.effective_alpha();
103 let period = self.params.period;
104 let default_alpha = 2.0 / (period as f64 + 1.0);
105
106 let values = if (alpha - default_alpha).abs() < f64::EPSILON {
107 functions::ema(&prices, period)?
109 } else {
110 ema_with_alpha(&prices, period, alpha)?
112 };
113
114 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
115 }
116}
117
118fn ema_with_alpha(prices: &[f64], period: usize, alpha: f64) -> Result<Vec<f64>, IndicatorError> {
126 if prices.len() < period {
127 return Err(IndicatorError::InsufficientData {
128 required: period,
129 available: prices.len(),
130 });
131 }
132 let mut result = vec![f64::NAN; prices.len()];
133 let seed: f64 = prices.iter().take(period).sum::<f64>() / period as f64;
134 result[period - 1] = seed;
135 let one_minus = 1.0 - alpha;
136 for i in period..prices.len() {
137 result[i] = prices[i] * alpha + result[i - 1] * one_minus;
138 }
139 Ok(result)
140}
141
142pub fn factory<S: ::std::hash::BuildHasher>(
145 params: &HashMap<String, String, S>,
146) -> Result<Box<dyn Indicator>, IndicatorError> {
147 let period = param_usize(params, "period", 20)?;
148 let alpha = if params.contains_key("alpha") {
149 Some(param_f64(params, "alpha", 2.0 / (period as f64 + 1.0))?)
150 } else {
151 None
152 };
153 let column = match param_str(params, "column", "close") {
154 "open" => PriceColumn::Open,
155 "high" => PriceColumn::High,
156 "low" => PriceColumn::Low,
157 _ => PriceColumn::Close,
158 };
159 Ok(Box::new(Ema::new(EmaParams {
160 period,
161 alpha,
162 column,
163 })))
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds flat candles (open = high = low = close) from a list of closes,
    /// using the element index as `time`.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn ema_insufficient_data() {
        let ema = Ema::with_period(5);
        assert!(ema.calculate(&candles(&[1.0, 2.0])).is_err());
    }

    #[test]
    fn ema_output_column_named_correctly() {
        let ema = Ema::with_period(3);
        let out = ema.calculate(&candles(&[10.0, 20.0, 30.0])).unwrap();
        assert!(out.get("EMA_3").is_some());
    }

    #[test]
    fn ema_seed_equals_sma() {
        // The first defined value is the simple mean of the first `period` closes.
        let closes = [10.0, 20.0, 30.0];
        let ema = Ema::with_period(3);
        let out = ema.calculate(&candles(&closes)).unwrap();
        let vals = out.get("EMA_3").unwrap();
        let expected_seed = (10.0 + 20.0 + 30.0) / 3.0;
        assert!((vals[2] - expected_seed).abs() < 1e-9, "got {}", vals[2]);
    }

    #[test]
    fn ema_subsequent_value() {
        // period 3 => alpha = 2 / (3 + 1) = 0.5; seed = SMA(10, 20, 30) = 20.
        let closes = [10.0, 20.0, 30.0, 40.0];
        let ema = Ema::with_period(3);
        let out = ema.calculate(&candles(&closes)).unwrap();
        let vals = out.get("EMA_3").unwrap();
        let expected = 40.0 * 0.5 + 20.0 * 0.5;
        assert!((vals[3] - expected).abs() < 1e-6, "got {}", vals[3]);
    }

    #[test]
    fn ema_custom_alpha_differs_from_default() {
        let closes = [10.0, 20.0, 30.0, 40.0];
        let default_out = Ema::with_period(3).calculate(&candles(&closes)).unwrap();
        let custom_out = Ema::new(EmaParams {
            period: 3,
            alpha: Some(0.1),
            column: PriceColumn::Close,
        })
        .calculate(&candles(&closes))
        .unwrap();
        let d = default_out.get("EMA_3").unwrap();
        let c = custom_out.get("EMA_3").unwrap();
        // Both variants share the same SMA seed...
        assert!((c[2] - d[2]).abs() < 1e-9);
        // ...but diverge once the smoothing factor is applied.
        assert!(
            (c[3] - d[3]).abs() > 1e-6,
            "custom alpha should differ: {}",
            c[3]
        );
    }

    #[test]
    fn ema_custom_alpha_correct_value() {
        // seed = 20; next = 40 * 0.1 + 20 * 0.9 = 22.
        let closes = [10.0, 20.0, 30.0, 40.0];
        let ema = Ema::new(EmaParams {
            period: 3,
            alpha: Some(0.1),
            column: PriceColumn::Close,
        });
        let out = ema.calculate(&candles(&closes)).unwrap();
        let vals = out.get("EMA_3").unwrap();
        assert!((vals[3] - 22.0).abs() < 1e-9, "got {}", vals[3]);
    }

    #[test]
    fn factory_creates_ema() {
        let params: HashMap<String, String> =
            HashMap::from([("period".to_string(), "12".to_string())]);
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "EMA");
    }

    #[test]
    fn factory_applies_custom_alpha() {
        let params: HashMap<String, String> = HashMap::from([
            ("period".to_string(), "3".to_string()),
            ("alpha".to_string(), "0.1".to_string()),
        ]);
        let ind = factory(&params).unwrap();
        let out = ind.calculate(&candles(&[10.0, 20.0, 30.0, 40.0])).unwrap();
        let vals = out.get("EMA_3").unwrap();
        assert!((vals[3] - 22.0).abs() < 1e-9, "got {}", vals[3]);
    }
}