1use std::collections::HashMap;
19
20use crate::error::IndicatorError;
21use crate::functions::{self};
22use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
23use crate::registry::{param_f64, param_str, param_usize};
24use crate::types::Candle;
25
26#[derive(Debug, Clone)]
29pub struct EmaParams {
30 pub period: usize,
32 pub alpha: Option<f64>,
35 pub column: PriceColumn,
37}
38
39impl Default for EmaParams {
40 fn default() -> Self {
41 Self {
42 period: 20,
43 alpha: None,
44 column: PriceColumn::Close,
45 }
46 }
47}
48
49impl EmaParams {
50 fn effective_alpha(&self) -> f64 {
51 self.alpha
52 .unwrap_or_else(|| 2.0 / (self.period as f64 + 1.0))
53 }
54}
55
56#[derive(Debug, Clone)]
59pub struct Ema {
60 pub params: EmaParams,
61}
62
63impl Ema {
64 pub fn new(params: EmaParams) -> Self {
65 Self { params }
66 }
67
68 pub fn with_period(period: usize) -> Self {
69 Self::new(EmaParams {
70 period,
71 ..Default::default()
72 })
73 }
74
75 fn output_key(&self) -> String {
76 format!("EMA_{}", self.params.period)
77 }
78}
79
80impl Indicator for Ema {
81 fn name(&self) -> &'static str {
82 "EMA"
83 }
84
85 fn required_len(&self) -> usize {
86 self.params.period
87 }
88
89 fn required_columns(&self) -> &[&'static str] {
90 &["close"]
91 }
92
93 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
98 self.check_len(candles)?;
99
100 let prices = self.params.column.extract(candles);
101 let _alpha = self.params.effective_alpha();
102 let _n = prices.len();
103 let period = self.params.period;
104
105 let values = functions::ema(&prices, period)?;
108
109 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
113 }
114}
115
116pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
119 let period = param_usize(params, "period", 20)?;
120 let alpha = if params.contains_key("alpha") {
121 Some(param_f64(params, "alpha", 2.0 / (period as f64 + 1.0))?)
122 } else {
123 None
124 };
125 let column = match param_str(params, "column", "close") {
126 "open" => PriceColumn::Open,
127 "high" => PriceColumn::High,
128 "low" => PriceColumn::Low,
129 _ => PriceColumn::Close,
130 };
131 Ok(Box::new(Ema::new(EmaParams {
132 period,
133 alpha,
134 column,
135 })))
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds flat candles (open = high = low = close) from a list of closes,
    /// using the index as the timestamp and a unit volume.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn ema_insufficient_data() {
        let ema = Ema::with_period(5);
        assert!(ema.calculate(&candles(&[1.0, 2.0])).is_err());
    }

    #[test]
    fn ema_output_column_named_correctly() {
        let ema = Ema::with_period(3);
        let out = ema.calculate(&candles(&[10.0, 20.0, 30.0])).unwrap();
        assert!(out.get("EMA_3").is_some());
    }

    #[test]
    fn ema_seed_equals_sma() {
        let closes = vec![10.0, 20.0, 30.0];
        let ema = Ema::with_period(3);
        let out = ema.calculate(&candles(&closes)).unwrap();
        let vals = out.get("EMA_3").unwrap();
        // The first defined value (index period - 1) is the simple average.
        let expected_seed = (10.0 + 20.0 + 30.0) / 3.0;
        assert!((vals[2] - expected_seed).abs() < 1e-9, "got {}", vals[2]);
    }

    #[test]
    fn ema_subsequent_value() {
        let closes = vec![10.0, 20.0, 30.0, 40.0];
        let ema = Ema::with_period(3);
        let out = ema.calculate(&candles(&closes)).unwrap();
        let vals = out.get("EMA_3").unwrap();
        // alpha = 2 / (3 + 1) = 0.5; the previous value is the SMA seed (20.0).
        let expected = 40.0 * 0.5 + 20.0 * 0.5;
        assert!((vals[3] - expected).abs() < 1e-6, "got {}", vals[3]);
    }

    #[test]
    fn effective_alpha_defaults_to_period_formula() {
        let default = EmaParams::default();
        assert!((default.effective_alpha() - 2.0 / 21.0).abs() < 1e-12);

        let explicit = EmaParams {
            alpha: Some(0.1),
            ..EmaParams::default()
        };
        assert!((explicit.effective_alpha() - 0.1).abs() < 1e-12);
    }

    #[test]
    fn factory_creates_ema() {
        let params = [("period".into(), "12".into())].into();
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "EMA");
        assert_eq!(ind.required_len(), 12);
    }

    #[test]
    fn factory_accepts_alpha_and_column() {
        let params = [
            ("period".into(), "5".into()),
            ("alpha".into(), "0.3".into()),
            ("column".into(), "open".into()),
        ]
        .into();
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "EMA");
        assert_eq!(ind.required_len(), 5);
    }
}