use crate::math::MulAdd;
macro_rules! impl_ema_float {
    ($name:ident, $builder:ident, $ty:ty) => {
        #[doc = concat!(
            "Exponential moving average over `",
            stringify!($ty),
            "` samples.\n\n",
            "Each update blends the sample in as ",
            "`value = alpha * sample + (1 - alpha) * value`; the first sample of an ",
            "unseeded average is adopted as-is. Build one with `",
            stringify!($name),
            "::builder()`."
        )]
        #[derive(Debug, Clone)]
        pub struct $name {
            /// Smoothing factor, always in the open interval (0, 1).
            alpha: $ty,
            /// Cached `1 - alpha`.
            one_minus_alpha: $ty,
            /// Current average; only meaningful while `initialized` is true.
            value: $ty,
            /// Samples seen by `update` (starts at `min_samples` when seeded).
            count: u64,
            /// Samples required before the average is reported.
            min_samples: u64,
            /// Whether `value` holds a seed or at least one real sample.
            initialized: bool,
        }

        #[doc = concat!("Builder for [`", stringify!($name), "`].")]
        #[derive(Debug, Clone)]
        pub struct $builder {
            alpha: Option<$ty>,
            min_samples: u64,
            seed: Option<$ty>,
        }

        impl $name {
            /// Returns a builder with no smoothing factor, `min_samples = 1` and no seed.
            #[inline]
            #[must_use]
            pub fn builder() -> $builder {
                $builder {
                    alpha: Option::None,
                    min_samples: 1,
                    seed: Option::None,
                }
            }

            /// Accepts `alpha` only if it lies strictly between 0 and 1 (NaN is rejected).
            #[inline]
            fn validated_alpha(alpha: $ty) -> Result<$ty, crate::ConfigError> {
                if alpha > 0.0 as $ty && alpha < 1.0 as $ty {
                    Ok(alpha)
                } else {
                    Err(crate::ConfigError::Invalid("EMA alpha must be in (0, 1)"))
                }
            }

            /// Feeds one sample into the average.
            ///
            /// Returns the updated average once the EMA is primed, otherwise `None`.
            #[inline]
            #[must_use]
            pub fn update(&mut self, sample: $ty) -> Option<$ty> {
                self.count += 1;
                if self.initialized {
                    self.value = self.alpha.fma(sample, self.one_minus_alpha * self.value);
                } else {
                    // No seed and no earlier sample: adopt the sample directly instead of
                    // blending it with the zero placeholder. Tracking this with a flag
                    // (not `count == 1`) keeps a seed intact even when `min_samples` is 0.
                    self.value = sample;
                    self.initialized = true;
                }
                self.value()
            }

            /// Current average, or `None` until the EMA is primed.
            #[inline]
            #[must_use]
            pub fn value(&self) -> Option<$ty> {
                if self.is_primed() {
                    Option::Some(self.value)
                } else {
                    Option::None
                }
            }

            /// The smoothing factor in use.
            #[inline]
            #[must_use]
            pub fn alpha(&self) -> $ty {
                self.alpha
            }

            /// Samples seen so far (a seeded EMA starts at its `min_samples`).
            #[inline]
            #[must_use]
            pub fn count(&self) -> u64 {
                self.count
            }

            /// True once the average is initialized and `count >= min_samples`,
            /// i.e. exactly when `value()` returns `Some`.
            #[inline]
            #[must_use]
            pub fn is_primed(&self) -> bool {
                self.initialized && self.count >= self.min_samples
            }

            /// Clears the average and the sample count; any builder seed is discarded.
            #[inline]
            pub fn reset(&mut self) {
                self.value = 0.0 as $ty;
                self.count = 0;
                self.initialized = false;
            }

            /// Replaces the smoothing factor while keeping the current average and count.
            ///
            /// # Errors
            ///
            /// Returns `ConfigError::Invalid` if `alpha` is not in the open interval (0, 1);
            /// the EMA is left unchanged in that case.
            #[inline]
            pub fn reconfigure_alpha(&mut self, alpha: $ty) -> Result<(), crate::ConfigError> {
                let alpha = Self::validated_alpha(alpha)?;
                self.alpha = alpha;
                self.one_minus_alpha = 1.0 as $ty - alpha;
                Ok(())
            }
        }

        impl $builder {
            /// Sets the smoothing factor directly; it must be in (0, 1).
            #[inline]
            #[must_use]
            pub fn alpha(mut self, alpha: $ty) -> Self {
                self.alpha = Option::Some(alpha);
                self
            }

            /// Sets the smoothing factor from a half-life in samples:
            /// `alpha = 1 - exp(-ln 2 / halflife)`, so a sample's weight halves after
            /// `halflife` further updates. Non-positive half-lives yield an alpha that
            /// `build` rejects.
            #[inline]
            #[must_use]
            #[cfg(any(feature = "std", feature = "libm"))]
            pub fn halflife(mut self, halflife: $ty) -> Self {
                let ln2 = core::f64::consts::LN_2 as $ty;
                let alpha = 1.0 as $ty - crate::math::exp((-ln2 / halflife) as f64) as $ty;
                self.alpha = Option::Some(alpha);
                self
            }

            /// Sets the smoothing factor from a span: `alpha = 2 / (n + 1)`.
            /// Spans below 2 give `alpha >= 1`, which `build` rejects.
            #[inline]
            #[must_use]
            pub fn span(mut self, n: u64) -> Self {
                let alpha = 2.0 as $ty / (n as $ty + 1.0 as $ty);
                self.alpha = Option::Some(alpha);
                self
            }

            /// Number of samples required before the average is reported (default 1).
            #[inline]
            #[must_use]
            pub fn min_samples(mut self, min: u64) -> Self {
                self.min_samples = min;
                self
            }

            /// Starts the average at `value`; the EMA is then treated as already primed.
            #[inline]
            #[must_use]
            pub fn seed(mut self, value: $ty) -> Self {
                self.seed = Option::Some(value);
                self
            }

            /// Validates the configuration and constructs the EMA.
            ///
            /// # Errors
            ///
            /// * `ConfigError::Missing("alpha")` if no smoothing factor was configured.
            /// * `ConfigError::Invalid(_)` if the smoothing factor is not in (0, 1).
            #[inline]
            pub fn build(self) -> Result<$name, crate::ConfigError> {
                let alpha = self.alpha.ok_or(crate::ConfigError::Missing("alpha"))?;
                let alpha = $name::validated_alpha(alpha)?;
                // A seed counts as `min_samples` observations so the EMA is primed at once.
                let (value, count, initialized) = match self.seed {
                    Option::Some(seed) => (seed, self.min_samples, true),
                    Option::None => (0.0 as $ty, 0, false),
                };
                Ok($name {
                    alpha,
                    one_minus_alpha: 1.0 as $ty - alpha,
                    value,
                    count,
                    min_samples: self.min_samples,
                    initialized,
                })
            }
        }
    };
}
// Concrete floating-point EMA types (double and single precision).
impl_ema_float! { EmaF64, EmaF64Builder, f64 }
impl_ema_float! { EmaF32, EmaF32Builder, f32 }
/// Rounds `n` up to the nearest value of the form `2^k - 1` (an all-ones bit mask).
///
/// Returns 0 for `n == 0`; otherwise the smallest `2^k - 1` that is `>= n`.
/// Inputs of `2^63` and above map to `u64::MAX` instead of overflowing.
#[inline]
pub(crate) const fn next_power_of_two_minus_one(n: u64) -> u64 {
    if n == 0 {
        return 0;
    }
    // The all-ones mask covering every significant bit of `n` is exactly the
    // smallest `2^k - 1 >= n`. Deriving it from the leading-zero count avoids the
    // `n + 1` / `next_power_of_two` overflow of the naive formulation.
    u64::MAX >> n.leading_zeros()
}
/// Returns `log2(span + 1)` for spans of the form `2^k - 1` (i.e. returns `k`).
///
/// For any input this is the number of trailing one bits of `span`, which equals
/// `(span + 1).trailing_zeros()` but cannot overflow, so `u64::MAX` yields 64.
#[inline]
pub(crate) const fn log2_of_span_plus_one(span: u64) -> u32 {
    span.trailing_ones()
}
macro_rules! impl_ema_int {
    ($name:ident, $builder:ident, $ty:ty, $acc_ty:ty) => {
        #[doc = concat!(
            "Fixed-point exponential moving average over `",
            stringify!($ty),
            "` samples, accumulated in `",
            stringify!($acc_ty),
            "`.\n\n",
            "The requested span is rounded up to `2^shift - 1` and the smoothing factor ",
            "is `2^-shift`, so each update is a shift, a subtraction and an addition. ",
            "Build one with `",
            stringify!($name),
            "::builder()`."
        )]
        #[derive(Debug, Clone)]
        pub struct $name {
            /// Accumulator — stores the average scaled by `2^shift`, so its low
            /// `shift` bits carry the fractional part.
            acc: $acc_ty,
            /// Fractional bits in `acc`; equals `log2(span + 1)`.
            shift: u32,
            /// Effective (rounded-up) span, always `2^shift - 1`.
            span: u64,
            /// Samples seen by `update` (starts at `min_samples` when seeded).
            count: u64,
            /// Samples required before the average is reported.
            min_samples: u64,
            /// Whether `acc` holds a seed or at least one real sample.
            initialized: bool,
        }

        #[doc = concat!("Builder for [`", stringify!($name), "`].")]
        #[derive(Debug, Clone)]
        pub struct $builder {
            span: Option<u64>,
            min_samples: u64,
            seed: Option<$ty>,
        }

        impl $name {
            /// Largest supported `shift`: the accumulator's extra width over the sample
            /// type, minus one bit of headroom, so `sample << shift` and every
            /// accumulator value stay representable.
            const MAX_SHIFT: u32 = <$acc_ty>::BITS - <$ty>::BITS - 1;
            /// Largest accepted requested span, `2^MAX_SHIFT - 1`.
            const MAX_SPAN: u64 = (1u64 << Self::MAX_SHIFT) - 1;

            /// Returns a builder with no span, `min_samples = 1` and no seed.
            #[inline]
            #[must_use]
            pub fn builder() -> $builder {
                $builder {
                    span: Option::None,
                    min_samples: 1,
                    seed: Option::None,
                }
            }

            /// Validates a requested span and returns `(effective_span, shift)`.
            #[inline]
            fn resolve_span(requested: u64) -> Result<(u64, u32), crate::ConfigError> {
                if requested < 1 {
                    return Err(crate::ConfigError::Invalid("EMA span must be >= 1"));
                }
                if requested > Self::MAX_SPAN {
                    return Err(crate::ConfigError::Invalid(
                        "EMA span too large for accumulator precision",
                    ));
                }
                let effective = next_power_of_two_minus_one(requested);
                Ok((effective, log2_of_span_plus_one(effective)))
            }

            /// Feeds one sample into the average.
            ///
            /// Returns the updated average once the EMA is primed, otherwise `None`.
            #[inline]
            #[must_use]
            pub fn update(&mut self, sample: $ty) -> Option<$ty> {
                self.count += 1;
                // Lossless widening into the accumulator type.
                let sample = sample as $acc_ty;
                if self.initialized {
                    // With V = acc / 2^shift: V += (sample - V) / 2^shift, i.e.
                    // acc += sample - floor(V). Subtracting the floored value (rather than
                    // shifting the scaled difference, which rounds the correction down)
                    // lets a constant input converge exactly from either side instead of
                    // settling one unit low when rising.
                    self.acc += sample - (self.acc >> self.shift);
                } else {
                    self.acc = sample << self.shift;
                    self.initialized = true;
                }
                self.value()
            }

            /// Current average (floored), or `None` until the EMA is primed.
            #[inline]
            #[must_use]
            pub fn value(&self) -> Option<$ty> {
                if self.is_primed() {
                    // The average never leaves the range of the values fed in, so the
                    // narrowing cast cannot truncate.
                    Option::Some((self.acc >> self.shift) as $ty)
                } else {
                    Option::None
                }
            }

            /// The span actually in use: the requested span rounded up to `2^k - 1`.
            #[inline]
            #[must_use]
            pub fn effective_span(&self) -> u64 {
                self.span
            }

            /// Samples seen so far (a seeded EMA starts at its `min_samples`).
            #[inline]
            #[must_use]
            pub fn count(&self) -> u64 {
                self.count
            }

            /// True once the average is initialized and `count >= min_samples`,
            /// i.e. exactly when `value()` returns `Some`.
            #[inline]
            #[must_use]
            pub fn is_primed(&self) -> bool {
                self.initialized && self.count >= self.min_samples
            }

            /// Clears the average and the sample count; any builder seed is discarded.
            #[inline]
            pub fn reset(&mut self) {
                self.acc = 0;
                self.count = 0;
                self.initialized = false;
            }

            /// Changes the span while keeping the current average (up to the precision
            /// of the new shift) and the sample count.
            ///
            /// # Errors
            ///
            /// Returns `ConfigError::Invalid` if `span` is 0 or too large for the
            /// accumulator; the EMA is left unchanged in that case.
            #[inline]
            pub fn reconfigure_span(&mut self, span: u64) -> Result<(), crate::ConfigError> {
                let (effective, new_shift) = Self::resolve_span(span)?;
                if self.initialized {
                    // Rescale so `acc` keeps representing the same average.
                    if new_shift > self.shift {
                        self.acc <<= new_shift - self.shift;
                    } else {
                        self.acc >>= self.shift - new_shift;
                    }
                }
                self.shift = new_shift;
                self.span = effective;
                Ok(())
            }
        }

        impl $builder {
            /// Sets the requested span; it is rounded up to the next `2^k - 1`.
            ///
            /// NOTE: the resulting smoothing factor is `2^-shift = 1 / (effective_span + 1)`,
            /// half of the `2 / (n + 1)` convention used by the floating-point `span`
            /// builders.
            #[inline]
            #[must_use]
            pub fn span(mut self, n: u64) -> Self {
                self.span = Option::Some(n);
                self
            }

            /// Number of samples required before the average is reported (default 1).
            #[inline]
            #[must_use]
            pub fn min_samples(mut self, min: u64) -> Self {
                self.min_samples = min;
                self
            }

            /// Starts the average at `value`; the EMA is then treated as already primed.
            #[inline]
            #[must_use]
            pub fn seed(mut self, value: $ty) -> Self {
                self.seed = Option::Some(value);
                self
            }

            /// Validates the configuration and constructs the EMA.
            ///
            /// # Errors
            ///
            /// * `ConfigError::Missing("span")` if no span was configured.
            /// * `ConfigError::Invalid(_)` if the span is 0 or too large for the
            ///   accumulator.
            #[inline]
            pub fn build(self) -> Result<$name, crate::ConfigError> {
                let requested = self.span.ok_or(crate::ConfigError::Missing("span"))?;
                let (effective, shift) = $name::resolve_span(requested)?;
                // A seed counts as `min_samples` observations so the EMA is primed at once.
                let (acc, count, initialized) = match self.seed {
                    Option::Some(seed) => ((seed as $acc_ty) << shift, self.min_samples, true),
                    Option::None => (0, 0, false),
                };
                Ok($name {
                    acc,
                    shift,
                    span: effective,
                    count,
                    min_samples: self.min_samples,
                    initialized,
                })
            }
        }
    };
}
// Concrete fixed-point EMA types; each accumulator is twice the sample width.
impl_ema_int! { EmaI64, EmaI64Builder, i64, i128 }
impl_ema_int! { EmaI32, EmaI32Builder, i32, i64 }
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn first_sample_initializes() {
        let mut ema = EmaF64::builder().alpha(0.5).build().unwrap();
        assert_eq!(ema.update(100.0), Some(100.0));
        assert_eq!(ema.value(), Some(100.0));
    }

    #[test]
    fn convergence_toward_constant() {
        let mut ema = EmaF64::builder().alpha(0.1).build().unwrap();
        let _ = ema.update(0.0);
        for _ in 0..1000 {
            let _ = ema.update(100.0);
        }
        let val = ema.value().unwrap();
        assert!(
            (val - 100.0).abs() < 0.01,
            "EMA should converge to 100, got {val}"
        );
    }

    #[test]
    fn higher_alpha_reacts_faster() {
        let mut fast = EmaF64::builder().alpha(0.9).build().unwrap();
        let mut slow = EmaF64::builder().alpha(0.1).build().unwrap();
        let _ = fast.update(0.0);
        let _ = slow.update(0.0);
        let _ = fast.update(100.0);
        let _ = slow.update(100.0);
        let fast_val = fast.value().unwrap();
        let slow_val = slow.value().unwrap();
        assert!(
            fast_val > slow_val,
            "fast ({fast_val}) should react more than slow ({slow_val})"
        );
    }

    #[test]
    fn priming_behavior() {
        let mut ema = EmaF64::builder().alpha(0.5).min_samples(5).build().unwrap();
        for i in 1..5 {
            assert_eq!(ema.update(100.0), None, "sample {i} should not be primed");
            assert!(!ema.is_primed());
        }
        assert!(ema.update(100.0).is_some());
        assert!(ema.is_primed());
    }

    #[test]
    fn reset_clears_state() {
        let mut ema = EmaF64::builder().alpha(0.5).build().unwrap();
        let _ = ema.update(100.0);
        let _ = ema.update(200.0);
        ema.reset();
        assert_eq!(ema.count(), 0);
        assert_eq!(ema.value(), None);
        assert_eq!(ema.update(50.0), Some(50.0));
    }

    #[test]
    fn span_computes_alpha() {
        let ema = EmaF64::builder().span(19).build().unwrap();
        assert!((ema.alpha() - 0.1).abs() < 1e-10);
    }

    // `halflife` only exists when an `exp` implementation is available, so the
    // test must be gated the same way or the test module fails to compile.
    #[test]
    #[cfg(any(feature = "std", feature = "libm"))]
    fn halflife_computes_alpha() {
        let ema = EmaF64::builder().halflife(1.0).build().unwrap();
        assert!((ema.alpha() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn errors_without_alpha() {
        let result = EmaF64::builder().build();
        assert!(matches!(result, Err(crate::ConfigError::Missing("alpha"))));
    }

    #[test]
    fn errors_on_alpha_zero() {
        let result = EmaF64::builder().alpha(0.0).build();
        assert!(matches!(result, Err(crate::ConfigError::Invalid(_))));
    }

    #[test]
    fn errors_on_alpha_one() {
        let result = EmaF64::builder().alpha(1.0).build();
        assert!(matches!(result, Err(crate::ConfigError::Invalid(_))));
    }

    #[test]
    fn f32_basic() {
        let mut ema = EmaF32::builder().alpha(0.5).build().unwrap();
        assert_eq!(ema.update(100.0), Some(100.0));
        let v = ema.update(200.0).unwrap();
        assert!((v - 150.0).abs() < 0.01);
    }

    #[test]
    fn span_rounding() {
        let ema = EmaI64::builder().span(1).build().unwrap();
        assert_eq!(ema.effective_span(), 1);
        let ema = EmaI64::builder().span(2).build().unwrap();
        assert_eq!(ema.effective_span(), 3);
        let ema = EmaI64::builder().span(3).build().unwrap();
        assert_eq!(ema.effective_span(), 3);
        let ema = EmaI64::builder().span(7).build().unwrap();
        assert_eq!(ema.effective_span(), 7);
        let ema = EmaI64::builder().span(10).build().unwrap();
        assert_eq!(ema.effective_span(), 15);
        let ema = EmaI64::builder().span(20).build().unwrap();
        assert_eq!(ema.effective_span(), 31);
    }

    #[test]
    fn int_first_sample_initializes() {
        let mut ema = EmaI64::builder().span(7).build().unwrap();
        assert_eq!(ema.update(1000), Some(1000));
    }

    #[test]
    fn int_convergence() {
        let mut ema = EmaI64::builder().span(7).build().unwrap();
        let _ = ema.update(0);
        for _ in 0..10_000 {
            let _ = ema.update(1000);
        }
        let val = ema.value().unwrap();
        assert!(
            (val - 1000).abs() <= 1,
            "should converge to ~1000, got {val}"
        );
    }

    #[test]
    fn int_no_drift_over_many_samples() {
        let mut ema = EmaI64::builder().span(15).build().unwrap();
        for _ in 0..100_000 {
            let _ = ema.update(500);
        }
        let val = ema.value().unwrap();
        assert_eq!(
            val, 500,
            "constant input should produce exact output, got {val}"
        );
    }

    #[test]
    fn int_priming() {
        let mut ema = EmaI64::builder().span(7).min_samples(5).build().unwrap();
        for _ in 0..4 {
            assert_eq!(ema.update(100), None);
        }
        assert!(ema.update(100).is_some());
    }

    #[test]
    fn int_reset() {
        let mut ema = EmaI64::builder().span(7).build().unwrap();
        let _ = ema.update(1000);
        let _ = ema.update(2000);
        ema.reset();
        assert_eq!(ema.count(), 0);
        assert_eq!(ema.value(), None);
    }

    #[test]
    fn i32_basic() {
        let mut ema = EmaI32::builder().span(3).build().unwrap();
        assert_eq!(ema.update(100), Some(100));
    }

    #[test]
    fn int_errors_without_span() {
        let result = EmaI64::builder().build();
        assert!(matches!(result, Err(crate::ConfigError::Missing("span"))));
    }

    #[test]
    // Exact float comparison is intended: reconfiguring alpha must leave the
    // stored value bit-for-bit untouched.
    #[allow(clippy::float_cmp)]
    fn float_reconfigure_alpha_preserves_value() {
        let mut ema = EmaF64::builder().alpha(0.5).build().unwrap();
        let _ = ema.update(100.0);
        let _ = ema.update(200.0);
        let val_before = ema.value().unwrap();
        let count_before = ema.count();
        ema.reconfigure_alpha(0.9).unwrap();
        assert!((ema.alpha() - 0.9).abs() < 1e-10);
        assert_eq!(ema.value().unwrap(), val_before);
        assert_eq!(ema.count(), count_before);
    }

    #[test]
    fn float_reconfigure_alpha_validates() {
        let mut ema = EmaF64::builder().alpha(0.5).build().unwrap();
        assert!(ema.reconfigure_alpha(0.0).is_err());
        assert!(ema.reconfigure_alpha(1.0).is_err());
        assert!(ema.reconfigure_alpha(-0.1).is_err());
    }

    #[test]
    fn int_reconfigure_span_preserves_value() {
        let mut ema = EmaI64::builder().span(7).build().unwrap();
        for _ in 0..100 {
            let _ = ema.update(500);
        }
        let val_before = ema.value().unwrap();
        let count_before = ema.count();
        ema.reconfigure_span(15).unwrap();
        assert_eq!(ema.effective_span(), 15);
        assert_eq!(ema.value().unwrap(), val_before);
        assert_eq!(ema.count(), count_before);
    }

    #[test]
    fn int_reconfigure_span_validates() {
        let mut ema = EmaI64::builder().span(7).build().unwrap();
        assert!(ema.reconfigure_span(0).is_err());
    }

    #[test]
    fn int_vs_float_comparison() {
        let mut int_ema = EmaI64::builder().span(15).build().unwrap();
        let mut float_ema = EmaF64::builder().span(15).build().unwrap();
        let samples = [100, 110, 95, 105, 120, 90, 100, 115, 85, 100];
        for &s in &samples {
            let _ = int_ema.update(s);
            let _ = float_ema.update(s as f64);
        }
        let int_val = int_ema.value().unwrap();
        let float_val = float_ema.value().unwrap();
        let diff = (int_val as f64 - float_val).abs();
        assert!(
            diff < 5.0,
            "int ({int_val}) and float ({float_val}) should be close, diff={diff}"
        );
    }
}