Skip to main content

quantwave_core/indicators/incremental/
talib_ema.rs

1//! TA-Lib compatible EMA state machine (SMA seed, NaN until lookback).
2
3use crate::traits::Next;
4
/// Streaming exponential moving average intended to reproduce
/// `talib_rs::overlap::ema` bar for bar.
///
/// The first `period - 1` non-NaN bars yield `NaN`; the `period`-th bar yields the
/// simple average of the bars seen so far (the seed); every later bar applies
/// exponential smoothing with factor `k`.
#[derive(Clone, Debug)]
pub struct TalibEma {
    /// Window length supplied to the constructor.
    period: usize,
    /// Smoothing factor applied to each post-seed bar.
    k: f64,
    /// Number of leading bars that produce `NaN` before the seed is emitted.
    lookback: usize,
    /// How many non-NaN bars have been consumed so far.
    bars_seen: usize,
    /// Running sum of the warm-up bars, divided by `period` to form the seed.
    seed_sum: f64,
    /// Latest EMA value; `NaN` until the seed has been produced.
    value: f64,
}
15
16impl TalibEma {
17    pub fn new(period: usize) -> Self {
18        Self {
19            period,
20            k: 2.0 / (period as f64 + 1.0),
21            lookback: period.saturating_sub(1),
22            bars_seen: 0,
23            seed_sum: 0.0,
24            value: f64::NAN,
25        }
26    }
27}
28
29impl Next<f64> for TalibEma {
30    type Output = f64;
31
32    fn next(&mut self, input: f64) -> Self::Output {
33        if self.period == 0 || input.is_nan() {
34            return f64::NAN;
35        }
36        let i = self.bars_seen;
37        self.bars_seen += 1;
38
39        if i < self.lookback {
40            self.seed_sum += input;
41            return f64::NAN;
42        }
43
44        if i == self.lookback {
45            self.seed_sum += input;
46            self.value = self.seed_sum / self.period as f64;
47            return self.value;
48        }
49
50        self.value = self.k.mul_add(input - self.value, self.value);
51        self.value
52    }
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Streaming output must line up index-for-index with the batch
        // `talib_rs::overlap::ema` result across a range of periods and input
        // lengths, including inputs shorter than the period.
        #[test]
        fn test_talib_ema_parity(
            input in prop::collection::vec(0.1..100.0, 1..100),
            period in 2usize..=30
        ) {
            let mut ema = TalibEma::new(period);
            let streaming: Vec<f64> = input.iter().map(|&x| ema.next(x)).collect();
            // A batch error (e.g. too few bars for the period) is treated as
            // "all NaN", which is what the streaming version emits then.
            let batch = talib_rs::overlap::ema(&input, period)
                .unwrap_or_else(|_| vec![f64::NAN; input.len()]);
            // `zip` alone would silently truncate a length mismatch and hide
            // an alignment bug between the two outputs.
            assert_eq!(streaming.len(), batch.len());
            for (i, (s, b)) in streaming.iter().zip(batch.iter()).enumerate() {
                if s.is_nan() {
                    assert!(b.is_nan(), "bar {i}: streaming is NaN but batch is {b}");
                } else {
                    approx::assert_relative_eq!(s, b, epsilon = 1e-6);
                }
            }
        }
    }
}