use crate::bar_indicators::indicator_value::IndicatorValue;
use crate::bar_indicators::ohlcv_field::OhlcvField;
/// Streaming exponential moving average (EMA) over one field of an OHLCV bar.
///
/// Each update applies `ema = alpha * x + (1 - alpha) * ema` with
/// `alpha = 2 / (period + 1)`, where `x` is the value extracted from the bar by
/// the configured [`OhlcvField`]. The very first sample seeds the average
/// directly instead of being blended with the initial `0.0`.
#[derive(Debug, Clone)]
pub struct Ema {
    /// Look-back length; also the number of updates required before [`Ema::is_ready`].
    period: usize,
    /// Smoothing factor, `2 / (period + 1)`.
    alpha: f64,
    /// Which bar field is fed into the average.
    source: OhlcvField,
    /// Current EMA value; `0.0` until the first update.
    value: f64,
    /// Number of updates since construction or the last [`Ema::reset`] (saturating).
    count: usize,
}

impl Ema {
    /// Returns the configured look-back period.
    pub fn period(&self) -> usize {
        self.period
    }

    /// Creates an EMA of the given `period` computed over the bar's close.
    ///
    /// # Panics
    ///
    /// Panics if `period` is zero (see [`Ema::with_source`]).
    pub fn new(period: usize) -> Self {
        Self::with_source(period, OhlcvField::Close)
    }

    /// Creates an EMA of the given `period` computed over the chosen bar `source`.
    ///
    /// # Panics
    ///
    /// Panics if `period` is zero. A zero period would give `alpha = 2`, which
    /// puts a weight of `-1` on the previous value (so the output oscillates
    /// instead of smoothing) and would make [`Ema::is_ready`] report `true`
    /// before any data had been seen.
    pub fn with_source(period: usize, source: OhlcvField) -> Self {
        assert!(period > 0, "EMA period must be greater than zero");
        Self {
            period,
            alpha: 2.0 / (period as f64 + 1.0),
            source,
            value: 0.0,
            count: 0,
        }
    }

    /// Feeds one bar into the average and returns the updated EMA value.
    ///
    /// The sample that enters the average is whatever the configured source's
    /// `extract` returns for this bar's `open`, `high`, `low`, `close` and `volume`.
    pub fn update_bar(&mut self, open: f64, high: f64, low: f64, close: f64, volume: f64) -> f64 {
        let sample = self.source.extract(open, high, low, close, volume);
        self.value = if self.count == 0 {
            // Seed with the first sample so the average is not dragged towards 0.0.
            sample
        } else {
            self.alpha * sample + (1.0 - self.alpha) * self.value
        };
        // Saturate instead of wrapping: a wrapped count of 0 would re-seed the average.
        self.count = self.count.saturating_add(1);
        self.value
    }

    /// Returns the current EMA wrapped as [`IndicatorValue::Single`].
    pub fn value(&self) -> IndicatorValue {
        IndicatorValue::Single(self.value)
    }

    /// Returns the current EMA as a plain `f64` (`0.0` before the first update).
    pub fn value_f64(&self) -> f64 {
        self.value
    }

    /// Returns `true` once at least `period` bars have been fed since
    /// construction or the last [`Ema::reset`].
    pub fn is_ready(&self) -> bool {
        self.count >= self.period
    }

    /// Clears the accumulated state; `period` and `source` are kept.
    ///
    /// The next [`Ema::update_bar`] call seeds the average again.
    pub fn reset(&mut self) {
        self.value = 0.0;
        self.count = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// Sample bars as `(open, high, low, close, volume)` tuples.
    const BARS: [(f64, f64, f64, f64, f64); 3] = [
        (100.0, 110.0, 90.0, 105.0, 1000.0),
        (105.0, 115.0, 95.0, 110.0, 1200.0),
        (110.0, 120.0, 100.0, 115.0, 800.0),
    ];

    /// Feeds a bar whose only non-zero field is `close`; returns the updated EMA.
    fn push_close(ema: &mut Ema, close: f64) -> f64 {
        ema.update_bar(0.0, 0.0, 0.0, close, 0.0)
    }

    /// Feeds every bar into `ema` in order and returns the final EMA value.
    fn feed(ema: &mut Ema, bars: &[(f64, f64, f64, f64, f64)]) -> f64 {
        for &(o, h, l, c, v) in bars {
            ema.update_bar(o, h, l, c, v);
        }
        ema.value_f64()
    }

    #[test]
    fn test_ema_basic_calculation() {
        let mut ema = Ema::new(3);
        let v1 = push_close(&mut ema, 10.0);
        assert!((v1 - 10.0).abs() < EPS);
        assert!(!ema.is_ready());
        let v2 = push_close(&mut ema, 20.0);
        assert!((v2 - 15.0).abs() < EPS);
        assert!(!ema.is_ready());
        let v3 = push_close(&mut ema, 30.0);
        assert!(ema.is_ready());
        assert!((v3 - 22.5).abs() < EPS);
    }

    #[test]
    fn test_ema_alpha_calculation() {
        assert!((Ema::new(9).alpha - 0.2).abs() < EPS);
        assert!((Ema::new(19).alpha - 0.1).abs() < EPS);
    }

    #[test]
    fn test_ema_reset() {
        let mut ema = Ema::new(3);
        for &close in &[10.0, 20.0, 30.0] {
            push_close(&mut ema, close);
        }
        assert!(ema.is_ready());
        ema.reset();
        assert!(!ema.is_ready());
        assert!(ema.value_f64().abs() < EPS);
        assert_eq!(ema.period(), 3);
        // After a reset the next sample seeds the average again.
        assert!((push_close(&mut ema, 40.0) - 40.0).abs() < EPS);
    }

    #[test]
    fn test_ema_value_types() {
        let mut ema = Ema::new(2);
        push_close(&mut ema, 10.0);
        push_close(&mut ema, 20.0);
        let f64_val = ema.value_f64();
        if let IndicatorValue::Single(v) = ema.value() {
            assert!((v - f64_val).abs() < EPS);
        } else {
            panic!("Expected Single variant");
        }
    }

    #[test]
    fn test_ema_with_different_sources() {
        // Period 3 => alpha = 0.5.
        assert!((feed(&mut Ema::new(3), &BARS) - 111.25).abs() < EPS);

        let hl2 = feed(&mut Ema::with_source(3, OhlcvField::HL2), &BARS);
        assert!((hl2 - 106.25).abs() < EPS);

        let open = feed(&mut Ema::with_source(3, OhlcvField::Open), &BARS);
        assert!((open - 106.25).abs() < EPS);

        let hlc3_at = |i: usize| {
            let (_, h, l, c, _) = BARS[i];
            (h + l + c) / 3.0
        };
        let expected = 0.5 * hlc3_at(2) + 0.5 * (0.5 * hlc3_at(1) + 0.5 * hlc3_at(0));
        let hlc3 = feed(&mut Ema::with_source(3, OhlcvField::HLC3), &BARS);
        assert!((hlc3 - expected).abs() < EPS);
    }
}