// indicators/trend/ema.rs

//! Exponential Moving Average (EMA).
//!
//! Python source: `indicators/trend/moving_average.py :: class EMA`
//!
//! # Python algorithm (to port)
//! ```python
//! ema = data[self.column].ewm(span=self.period, adjust=False, alpha=self.alpha).mean()
//! ```
//!
//! Note: `self.alpha = params.get("alpha", 2 / (period + 1))`.
//! pandas treats `span` and `alpha` as mutually exclusive. This port therefore
//! reads the call as "smooth with `alpha`, which defaults to `2 / (period + 1)`".
//!
//! Output column: `"EMA_{period}"`.
//!
//! See also: `crate::functions::ema()`, the batch implementation that
//! `calculate()` delegates to when the standard smoothing factor is in effect,
//! and `crate::functions::EMA`, the existing incremental struct.

18use std::collections::HashMap;
19
20use crate::error::IndicatorError;
21use crate::functions::{self};
22use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
23use crate::registry::{param_f64, param_str, param_usize};
24use crate::types::Candle;
25
26// ── Params ────────────────────────────────────────────────────────────────────
27
28#[derive(Debug, Clone)]
29pub struct EmaParams {
30    /// Lookback period (span).  Python default: 20.
31    pub period: usize,
32    /// Smoothing factor.  Python default: `2 / (period + 1)`.
33    /// Pass `None` to use the standard formula.
34    pub alpha: Option<f64>,
35    /// Price field.  Python default: `"close"`.
36    pub column: PriceColumn,
37}
38
39impl Default for EmaParams {
40    fn default() -> Self {
41        Self {
42            period: 20,
43            alpha: None,
44            column: PriceColumn::Close,
45        }
46    }
47}
48
49impl EmaParams {
50    fn effective_alpha(&self) -> f64 {
51        self.alpha
52            .unwrap_or_else(|| 2.0 / (self.period as f64 + 1.0))
53    }
54}
55
56// ── Indicator struct ──────────────────────────────────────────────────────────
57
58#[derive(Debug, Clone)]
59pub struct Ema {
60    pub params: EmaParams,
61}
62
63impl Ema {
64    pub fn new(params: EmaParams) -> Self {
65        Self { params }
66    }
67
68    pub fn with_period(period: usize) -> Self {
69        Self::new(EmaParams {
70            period,
71            ..Default::default()
72        })
73    }
74
75    fn output_key(&self) -> String {
76        format!("EMA_{}", self.params.period)
77    }
78}
79
80impl Indicator for Ema {
81    fn name(&self) -> &'static str {
82        "EMA"
83    }
84
85    fn required_len(&self) -> usize {
86        self.params.period
87    }
88
89    fn required_columns(&self) -> &[&'static str] {
90        &["close"]
91    }
92
93    /// Ports Python `ewm(span=period, adjust=False, alpha=alpha).mean()`.
94    ///
95    /// Delegates to `crate::functions::ema()` when alpha matches the standard
96    /// `2/(period+1)` formula, and uses a local SMA-seeded EMA loop when a
97    /// custom alpha is provided.
98    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
99        self.check_len(candles)?;
100
101        let prices = self.params.column.extract(candles);
102        let alpha = self.params.effective_alpha();
103        let period = self.params.period;
104        let default_alpha = 2.0 / (period as f64 + 1.0);
105
106        let values = if (alpha - default_alpha).abs() < f64::EPSILON {
107            // Fast path: delegate to the shared batch implementation.
108            functions::ema(&prices, period)?
109        } else {
110            // Custom alpha path: SMA-seed then apply caller-supplied smoothing factor.
111            ema_with_alpha(&prices, period, alpha)?
112        };
113
114        Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
115    }
116}
117
118// ── Helpers ───────────────────────────────────────────────────────────────────
119
120/// SMA-seeded EMA with a caller-supplied smoothing factor.
121///
122/// Mirrors Python `series.ewm(span=period, adjust=False, alpha=alpha).mean()`.
123/// The seed (index `period-1`) is the arithmetic mean of the first `period`
124/// values; subsequent values follow `ema[i] = alpha * price[i] + (1-alpha) * ema[i-1]`.
125fn ema_with_alpha(prices: &[f64], period: usize, alpha: f64) -> Result<Vec<f64>, IndicatorError> {
126    if prices.len() < period {
127        return Err(IndicatorError::InsufficientData {
128            required: period,
129            available: prices.len(),
130        });
131    }
132    let mut result = vec![f64::NAN; prices.len()];
133    let seed: f64 = prices.iter().take(period).sum::<f64>() / period as f64;
134    result[period - 1] = seed;
135    let one_minus = 1.0 - alpha;
136    for i in period..prices.len() {
137        result[i] = prices[i] * alpha + result[i - 1] * one_minus;
138    }
139    Ok(result)
140}
141
142// ── Registry factory ──────────────────────────────────────────────────────────
143
144pub fn factory<S: ::std::hash::BuildHasher>(
145    params: &HashMap<String, String, S>,
146) -> Result<Box<dyn Indicator>, IndicatorError> {
147    let period = param_usize(params, "period", 20)?;
148    let alpha = if params.contains_key("alpha") {
149        Some(param_f64(params, "alpha", 2.0 / (period as f64 + 1.0))?)
150    } else {
151        None
152    };
153    let column = match param_str(params, "column", "close") {
154        "open" => PriceColumn::Open,
155        "high" => PriceColumn::High,
156        "low" => PriceColumn::Low,
157        _ => PriceColumn::Close,
158    };
159    Ok(Box::new(Ema::new(EmaParams {
160        period,
161        alpha,
162        column,
163    })))
164}
165
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Flat candles (open = high = low = close) with unit volume and a
    /// sequential timestamp, one per price.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        (0_i64..)
            .zip(closes)
            .map(|(time, &price)| Candle {
                time,
                open: price,
                high: price,
                low: price,
                close: price,
                volume: 1.0,
            })
            .collect()
    }

    /// Period-3 EMA on close with an explicit, non-default alpha of 0.1.
    fn custom_alpha_ema() -> Ema {
        Ema::new(EmaParams {
            period: 3,
            alpha: Some(0.1),
            ..EmaParams::default()
        })
    }

    #[test]
    fn ema_insufficient_data() {
        let too_short = candles(&[1.0, 2.0]);
        assert!(Ema::with_period(5).calculate(&too_short).is_err());
    }

    #[test]
    fn ema_output_column_named_correctly() {
        let out = Ema::with_period(3)
            .calculate(&candles(&[10.0, 20.0, 30.0]))
            .unwrap();
        assert!(out.get("EMA_3").is_some());
    }

    #[test]
    fn ema_seed_equals_sma() {
        // Index `period - 1` must hold the plain mean of the first `period` closes.
        let closes = [10.0, 20.0, 30.0];
        let out = Ema::with_period(3).calculate(&candles(&closes)).unwrap();
        let vals = out.get("EMA_3").unwrap();
        let expected_seed = (10.0 + 20.0 + 30.0) / 3.0;
        assert!((vals[2] - expected_seed).abs() < 1e-9, "got {}", vals[2]);
    }

    #[test]
    fn ema_subsequent_value() {
        // alpha = 2/(3+1) = 0.5 and seed = 20, so EMA[3] = 0.5*40 + 0.5*20 = 30.
        let closes = [10.0, 20.0, 30.0, 40.0];
        let out = Ema::with_period(3).calculate(&candles(&closes)).unwrap();
        let vals = out.get("EMA_3").unwrap();
        let expected = 40.0 * 0.5 + 20.0 * 0.5;
        assert!((vals[3] - expected).abs() < 1e-6, "got {}", vals[3]);
    }

    #[test]
    fn ema_custom_alpha_differs_from_default() {
        // alpha = 0.1 vs the span-derived 0.5: identical seeds, divergent tails.
        let bars = candles(&[10.0, 20.0, 30.0, 40.0]);
        let default_out = Ema::with_period(3).calculate(&bars).unwrap();
        let custom_out = custom_alpha_ema().calculate(&bars).unwrap();
        let d = default_out.get("EMA_3").unwrap();
        let c = custom_out.get("EMA_3").unwrap();
        // The SMA seed does not depend on alpha...
        assert!((c[2] - d[2]).abs() < 1e-9);
        // ...but the first post-seed value does.
        assert!(
            (c[3] - d[3]).abs() > 1e-6,
            "custom alpha should differ: {}",
            c[3]
        );
    }

    #[test]
    fn ema_custom_alpha_correct_value() {
        // seed = (10 + 20 + 30) / 3 = 20, so EMA[3] = 0.1*40 + 0.9*20 = 22.
        let out = custom_alpha_ema()
            .calculate(&candles(&[10.0, 20.0, 30.0, 40.0]))
            .unwrap();
        let vals = out.get("EMA_3").unwrap();
        assert!((vals[3] - 22.0).abs() < 1e-9, "got {}", vals[3]);
    }

    #[test]
    fn factory_creates_ema() {
        let params = [("period".into(), "12".into())].into();
        let indicator = factory(&params).unwrap();
        assert_eq!(indicator.name(), "EMA");
    }
}