use crate::error::{Error, Result};
use crate::traits::Indicator;
/// Exponential moving average (EMA) indicator.
///
/// Produces no value until `period` samples have been accepted; the first
/// output is the simple mean of those samples, and every later sample is
/// blended into the running value using the smoothing factor `alpha`.
#[derive(Debug, Clone)]
pub struct Ema {
    /// Number of samples averaged to seed the EMA (`1` when built via `Ema::with_alpha`).
    period: usize,
    /// Smoothing factor applied to each post-seed sample; always in `(0.0, 1.0]`.
    alpha: f64,
    /// Current EMA value; `None` while still warming up.
    state: Option<f64>,
    /// Samples collected during warm-up, averaged to produce the seed value.
    warmup_buf: Vec<f64>,
}
impl Ema {
    /// Upper bound on how many warm-up slots are reserved up front.
    ///
    /// `period` is caller-controlled. Reserving `period` slots eagerly makes
    /// `Ema::new(usize::MAX)` panic with a capacity overflow (and merely huge
    /// periods abort on allocation failure) instead of returning. Longer
    /// warm-ups still work: the buffer grows on demand.
    const MAX_WARMUP_PREALLOC: usize = 1024;

    /// Creates an EMA whose smoothing factor is derived from `period` as
    /// `alpha = 2 / (period + 1)`.
    ///
    /// The first output is produced once `period` samples have been seen and
    /// equals their simple mean.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PeriodZero`] if `period` is `0`.
    pub fn new(period: usize) -> Result<Self> {
        if period == 0 {
            return Err(Error::PeriodZero);
        }
        let alpha = 2.0 / (period as f64 + 1.0);
        Ok(Self {
            period,
            alpha,
            state: None,
            warmup_buf: Vec::with_capacity(period.min(Self::MAX_WARMUP_PREALLOC)),
        })
    }

    /// Creates an EMA with an explicit smoothing factor.
    ///
    /// The warm-up is a single sample: the first accepted input becomes the
    /// initial value, and [`Ema::period`] reports `1`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidPeriod`] if `alpha` is not finite or lies
    /// outside `(0.0, 1.0]`.
    pub fn with_alpha(alpha: f64) -> Result<Self> {
        if !alpha.is_finite() || alpha <= 0.0 || alpha > 1.0 {
            return Err(Error::InvalidPeriod {
                message: "alpha must be in (0.0, 1.0]",
            });
        }
        Ok(Self {
            period: 1,
            alpha,
            state: None,
            warmup_buf: Vec::with_capacity(1),
        })
    }

    /// Length of the seeding window (`1` for instances built with [`Ema::with_alpha`]).
    pub const fn period(&self) -> usize {
        self.period
    }

    /// Smoothing factor applied to each post-seed sample.
    pub const fn alpha(&self) -> f64 {
        self.alpha
    }

    /// Most recent EMA value, or `None` while still warming up.
    pub const fn value(&self) -> Option<f64> {
        self.state
    }

    /// Advances the EMA by one sample without validating `input`.
    ///
    /// While warming up, the sample is buffered. Once `period` samples are
    /// buffered, their mean seeds the average and is returned. After that,
    /// each sample is blended as `alpha * input + (1 - alpha) * prev`.
    /// This method performs no finiteness check; `Indicator::update` filters
    /// non-finite inputs before delegating here.
    pub(crate) fn step_unchecked(&mut self, input: f64) -> Option<f64> {
        if let Some(prev) = self.state {
            let next = self.alpha.mul_add(input, (1.0 - self.alpha) * prev);
            self.state = Some(next);
            return Some(next);
        }
        self.warmup_buf.push(input);
        if self.warmup_buf.len() < self.period {
            return None;
        }
        // The buffer is never read again once the EMA is seeded (until a reset
        // clears `state`), so hand its allocation back instead of keeping
        // `period` dead samples alive for the lifetime of the indicator.
        let window = std::mem::take(&mut self.warmup_buf);
        let seed = window.iter().copied().sum::<f64>() / self.period as f64;
        self.state = Some(seed);
        Some(seed)
    }
}
impl Indicator for Ema {
    type Input = f64;
    type Output = f64;

    /// Feeds one observation and returns the current EMA, if any.
    ///
    /// Non-finite inputs (`NaN`, `±inf`) are dropped without touching the
    /// warm-up buffer or the running average; the previously published value
    /// (`None` while still warming up) is returned unchanged.
    fn update(&mut self, input: f64) -> Option<f64> {
        if input.is_finite() {
            self.step_unchecked(input)
        } else {
            self.state
        }
    }

    /// Discards buffered warm-up samples and the running average.
    fn reset(&mut self) {
        self.warmup_buf.clear();
        self.state = None;
    }

    /// Number of accepted samples required before the first output.
    fn warmup_period(&self) -> usize {
        self.period
    }

    /// `true` once the EMA has been seeded and emits values.
    fn is_ready(&self) -> bool {
        self.state.is_some()
    }

    /// Short display name of this indicator.
    fn name(&self) -> &'static str {
        "EMA"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Ema::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn warmup_returns_none_until_seed() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(3.0), Some(2.0));
    }

    #[test]
    fn first_value_equals_sma_seed() {
        let mut ema = Ema::new(5).unwrap();
        let inputs = [10.0, 20.0, 30.0, 40.0, 50.0];
        let mut last = None;
        for v in inputs {
            last = ema.update(v);
        }
        assert_relative_eq!(last.unwrap(), 30.0, epsilon = 1e-12);
    }

    #[test]
    fn alpha_matches_period_formula() {
        let ema = Ema::new(10).unwrap();
        assert_relative_eq!(ema.alpha(), 2.0 / 11.0, epsilon = 1e-15);
    }

    #[test]
    fn step_after_seed_uses_alpha_formula() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert_relative_eq!(ema.update(10.0).unwrap(), 6.0, epsilon = 1e-12);
    }

    #[test]
    fn constant_series_converges_to_constant() {
        let mut ema = Ema::new(10).unwrap();
        let out = ema.batch(&[42.0_f64; 100]);
        for x in out.iter().skip(9) {
            assert_relative_eq!(x.unwrap(), 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn with_alpha_validates_range() {
        assert!(Ema::with_alpha(0.5).is_ok());
        assert!(Ema::with_alpha(1.0).is_ok());
        assert!(matches!(
            Ema::with_alpha(0.0),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(-0.25),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(1.5),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(f64::NAN),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(f64::INFINITY),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    /// With `alpha == 1.0` the warm-up is one sample and every output equals
    /// the latest input exactly.
    #[test]
    fn with_alpha_one_tracks_input() {
        let mut ema = Ema::with_alpha(1.0).unwrap();
        assert_eq!(ema.warmup_period(), 1);
        assert_eq!(ema.update(5.0), Some(5.0));
        assert_eq!(ema.update(7.0), Some(7.0));
    }

    #[test]
    fn reset_clears_state() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert!(ema.is_ready());
        ema.reset();
        assert!(!ema.is_ready());
        assert_eq!(ema.update(1.0), None);
    }

    /// After `reset`, feeding the same series reproduces the same outputs.
    #[test]
    fn reset_replays_identically() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut ema = Ema::new(4).unwrap();
        let first = ema.batch(&prices);
        ema.reset();
        assert_eq!(ema.batch(&prices), first);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=30).map(f64::from).collect();
        let mut a = Ema::new(5).unwrap();
        let mut b = Ema::new(5).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        let before = ema.value();
        assert_eq!(ema.update(f64::NAN), before);
        assert_eq!(ema.update(f64::INFINITY), before);
    }

    /// A non-finite sample during warm-up must not count toward `period`.
    #[test]
    fn non_finite_input_during_warmup_is_not_counted() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(f64::NAN), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(3.0), Some(2.0));
    }
}