wickra-core 0.1.2

Core streaming-first technical indicators engine for the Wickra library
//! Exponential Moving Average.

use crate::error::{Error, Result};
use crate::traits::Indicator;

/// Exponential Moving Average.
///
/// Once seeded, each new input contributes `alpha * input + (1 - alpha) * previous`.
/// How the smoothing factor and the seed are chosen depends on the constructor:
///
/// * [`Ema::new`] uses `alpha = 2 / (period + 1)` and seeds with the simple mean of
///   the first `period` inputs (the classical TA-Lib convention).
/// * [`Ema::with_alpha`] takes `alpha` directly and seeds from the very first input.
///
/// Non-finite inputs passed to [`Indicator::update`] are ignored.
#[derive(Debug, Clone)]
pub struct Ema {
    // Number of inputs averaged into the seed (`1` for `with_alpha` instances).
    period: usize,
    // Smoothing factor; both constructors keep it in `(0, 1]`.
    alpha: f64,
    // Current EMA value; `None` until the seed has been computed.
    state: Option<f64>,
    // Inputs collected while waiting for `period` values to compute the seed.
    warmup_buf: Vec<f64>,
}

impl Ema {
    /// Construct an EMA with the given period.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PeriodZero`] if `period == 0`.
    pub fn new(period: usize) -> Result<Self> {
        if period == 0 {
            return Err(Error::PeriodZero);
        }
        let alpha = 2.0 / (period as f64 + 1.0);
        Ok(Self {
            period,
            alpha,
            state: None,
            warmup_buf: Vec::with_capacity(period),
        })
    }

    /// Construct an EMA with a custom smoothing factor `alpha in (0, 1]`.
    ///
    /// The reported `period` is derived from `alpha` via `2/alpha - 1` and rounded;
    /// `warmup_period()` falls back to `1` because the implementation seeds from the
    /// very first input.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidPeriod`] if `alpha` is not in `(0.0, 1.0]` or non-finite.
    pub fn with_alpha(alpha: f64) -> Result<Self> {
        if !alpha.is_finite() || alpha <= 0.0 || alpha > 1.0 {
            return Err(Error::InvalidPeriod {
                message: "alpha must be in (0.0, 1.0]",
            });
        }
        Ok(Self {
            period: 1,
            alpha,
            state: None,
            warmup_buf: Vec::with_capacity(1),
        })
    }

    /// Configured period.
    pub const fn period(&self) -> usize {
        self.period
    }

    /// Smoothing factor.
    pub const fn alpha(&self) -> f64 {
        self.alpha
    }

    /// Current value if available.
    pub const fn value(&self) -> Option<f64> {
        self.state
    }

    /// Internal helper that feeds a value without finiteness validation. The caller
    /// guarantees `input.is_finite()`. Used by MACD which has already validated.
    pub(crate) fn step_unchecked(&mut self, input: f64) -> Option<f64> {
        if let Some(prev) = self.state {
            let new = self.alpha.mul_add(input, (1.0 - self.alpha) * prev);
            self.state = Some(new);
            return Some(new);
        }
        self.warmup_buf.push(input);
        if self.warmup_buf.len() == self.period {
            let seed = self.warmup_buf.iter().copied().sum::<f64>() / self.period as f64;
            self.state = Some(seed);
            return Some(seed);
        }
        None
    }
}

impl Indicator for Ema {
    type Input = f64;
    type Output = f64;

    /// Accept one observation. NaN and infinite values are discarded without touching
    /// any state, and the current value (possibly `None`) is handed back unchanged.
    fn update(&mut self, input: f64) -> Option<f64> {
        if input.is_finite() {
            self.step_unchecked(input)
        } else {
            self.state
        }
    }

    /// Drop any buffered warm-up inputs and the current value.
    fn reset(&mut self) {
        self.warmup_buf.clear();
        self.state = None;
    }

    /// Accepted inputs required before the first output: the configured period.
    fn warmup_period(&self) -> usize {
        self.period
    }

    /// `true` once the seed has been computed.
    fn is_ready(&self) -> bool {
        self.value().is_some()
    }

    /// Short identifier for this indicator.
    fn name(&self) -> &'static str {
        "EMA"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Ema::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn warmup_returns_none_until_seed() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(3.0), Some(2.0)); // seed = SMA([1,2,3]) = 2
    }

    #[test]
    fn first_value_equals_sma_seed() {
        let mut ema = Ema::new(5).unwrap();
        let inputs = [10.0, 20.0, 30.0, 40.0, 50.0];
        let mut last = None;
        for v in inputs {
            last = ema.update(v);
        }
        assert_relative_eq!(last.unwrap(), 30.0, epsilon = 1e-12);
    }

    #[test]
    fn alpha_matches_period_formula() {
        let ema = Ema::new(10).unwrap();
        assert_relative_eq!(ema.alpha(), 2.0 / 11.0, epsilon = 1e-15);
    }

    #[test]
    fn step_after_seed_uses_alpha_formula() {
        // period=3 => alpha = 0.5; seed = mean([1,2,3]) = 2; next input 10
        // expected = 0.5*10 + 0.5*2 = 6
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert_relative_eq!(ema.update(10.0).unwrap(), 6.0, epsilon = 1e-12);
    }

    #[test]
    fn constant_series_converges_to_constant() {
        let mut ema = Ema::new(10).unwrap();
        let out = ema.batch(&[42.0_f64; 100]);
        for x in out.iter().skip(9) {
            assert_relative_eq!(x.unwrap(), 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn with_alpha_validates_range() {
        assert!(Ema::with_alpha(0.5).is_ok());
        assert!(Ema::with_alpha(1.0).is_ok());
        assert!(matches!(
            Ema::with_alpha(0.0),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(1.5),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(f64::NAN),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    #[test]
    fn with_alpha_seeds_from_first_input() {
        // No SMA warm-up: the first input is the seed, then alpha = 0.5 applies.
        let mut ema = Ema::with_alpha(0.5).unwrap();
        assert_eq!(ema.period(), 1);
        assert_eq!(ema.warmup_period(), 1);
        assert_relative_eq!(ema.update(4.0).unwrap(), 4.0, epsilon = 1e-12);
        assert_relative_eq!(ema.update(8.0).unwrap(), 6.0, epsilon = 1e-12);
    }

    #[test]
    fn reset_clears_state() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert!(ema.is_ready());
        ema.reset();
        assert!(!ema.is_ready());
        assert_eq!(ema.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=30).map(f64::from).collect();
        let mut a = Ema::new(5).unwrap();
        let mut b = Ema::new(5).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        let before = ema.value();
        assert_eq!(ema.update(f64::NAN), before);
        assert_eq!(ema.update(f64::INFINITY), before);
    }

    #[test]
    fn non_finite_during_warmup_is_not_buffered() {
        // A NaN mid-warm-up must neither count toward the period nor poison the seed.
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(f64::NAN), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(3.0), Some(2.0));
    }
}