use crate::indicators::utils::calculate_ema;
use crate::indicators::validate_period;
use crate::indicators::{Candle, Indicator, IndicatorError};
/// Exponential moving average (EMA) indicator.
///
/// Supports batch computation through `Indicator::calculate` and incremental,
/// one-value-at-a-time updates through `Indicator::next`. It accepts either raw
/// `f64` prices or `Candle`s (using their closing price).
///
/// `Clone` is derived so the streaming state can be snapshotted or forked,
/// e.g. to evaluate a hypothetical next value without committing it.
#[derive(Debug, Clone)]
pub struct Ema {
    /// Look-back period supplied to `Ema::new`; used by batch calculation.
    period: usize,
    /// Smoothing factor, computed in `Ema::new` as `2 / (period + 1)`.
    alpha: f64,
    /// Most recent streaming EMA value; `None` until the first `next` call
    /// (or an explicit seed), and again after a reset.
    current_ema: Option<f64>,
}
impl Ema {
    /// Creates an EMA with the given look-back `period`.
    ///
    /// The smoothing factor is `alpha = 2 / (period + 1)`. The streaming state
    /// starts empty, so the first value fed to `next` seeds the average.
    ///
    /// # Errors
    /// Propagates the error from `validate_period(period, 1)`; the tests in
    /// this file show that a period of `0` is rejected.
    pub fn new(period: usize) -> Result<Self, IndicatorError> {
        validate_period(period, 1)?;
        let alpha = 2.0 / (period as f64 + 1.0);
        Ok(Self {
            current_ema: None,
            alpha,
            period,
        })
    }

    /// Seeds the streaming state with `value`.
    ///
    /// The next call to `next` will blend its input with `value` instead of
    /// starting fresh. Returns `&mut Self` so calls can be chained.
    pub fn with_initial_value(&mut self, value: f64) -> &mut Self {
        self.current_ema = Some(value);
        self
    }

    /// Discards the streaming state so the next value starts a new average.
    pub fn reset_state(&mut self) {
        // The previous value is intentionally dropped.
        let _ = self.current_ema.take();
    }
}
impl Indicator<f64, f64> for Ema {
    /// Computes the EMA over the whole `data` slice in one batch.
    ///
    /// Delegates to `calculate_ema` with this indicator's `period`. The
    /// streaming state used by `next` is neither read nor modified.
    fn calculate(&mut self, data: &[f64]) -> Result<Vec<f64>, IndicatorError> {
        let period = self.period;
        calculate_ema(data, period)
    }

    /// Feeds one value into the streaming EMA and returns the updated average.
    ///
    /// With no prior state (fresh instance, after a reset, and not seeded via
    /// `with_initial_value`), the value itself becomes the average and is
    /// returned unchanged. Otherwise it is blended as
    /// `alpha * value + (1 - alpha) * previous`. Always returns `Ok(Some(_))`.
    fn next(&mut self, value: f64) -> Result<Option<f64>, IndicatorError> {
        let updated = match self.current_ema {
            None => value,
            Some(previous) => (value * self.alpha) + (previous * (1.0 - self.alpha)),
        };
        self.current_ema = Some(updated);
        Ok(Some(updated))
    }

    /// Clears the streaming state; same effect as `Ema::reset_state`.
    fn reset(&mut self) {
        self.reset_state();
    }
}
impl Indicator<Candle, f64> for Ema {
    /// Computes the EMA over the closing prices of `data` in one batch.
    ///
    /// Extracts `close` from each candle and delegates to the `f64`
    /// implementation, so candle and price inputs share a single code path
    /// and cannot drift apart. The streaming state used by `next` is not
    /// touched.
    fn calculate(&mut self, data: &[Candle]) -> Result<Vec<f64>, IndicatorError> {
        let close_prices: Vec<f64> = data.iter().map(|candle| candle.close).collect();
        <Self as Indicator<f64, f64>>::calculate(self, &close_prices)
    }

    /// Feeds the candle's closing price into the streaming EMA.
    ///
    /// Delegates to the `f64` implementation; fully-qualified syntax is
    /// required because `Ema` implements `Indicator` for two input types.
    fn next(&mut self, candle: Candle) -> Result<Option<f64>, IndicatorError> {
        <Self as Indicator<f64, f64>>::next(self, candle.close)
    }

    /// Clears the streaming state via `Ema::reset_state`, matching the `f64`
    /// implementation.
    fn reset(&mut self) {
        self.reset_state();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoothing factor for `Ema::new(3)`: 2 / (3 + 1).
    const ALPHA_3: f64 = 0.5;

    /// Builds a candle whose open equals `close`, with a ±0.5 high/low range
    /// and a fixed volume. Only `close` affects the EMA.
    fn candle(timestamp: u32, close: f64) -> Candle {
        Candle {
            // `as _` lets the cast target whatever integer type `timestamp` uses.
            timestamp: timestamp as _,
            open: close,
            high: close + 0.5,
            low: close - 0.5,
            close,
            volume: 1000.0,
        }
    }

    /// Builds one candle per closing price, with timestamps starting at 1.
    fn candles_from_closes(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .zip(1u32..)
            .map(|(&close, ts)| candle(ts, close))
            .collect()
    }

    /// Reference EMA seeded with the first value, using the incremental form
    /// `prev + alpha * (v - prev)`.
    fn reference_ema(values: &[f64], alpha: f64) -> Vec<f64> {
        let mut out: Vec<f64> = Vec::with_capacity(values.len());
        for &v in values {
            let next = match out.last() {
                Some(&prev) => (v - prev) * alpha + prev,
                None => v,
            };
            out.push(next);
        }
        out
    }

    /// Asserts equal length and element-wise agreement within 1e-12.
    fn assert_close(got: &[f64], want: &[f64]) {
        assert_eq!(got.len(), want.len());
        for (i, (&g, &w)) in got.iter().zip(want).enumerate() {
            assert!((g - w).abs() < 1e-12, "row {i}: got {g}, expected {w}");
        }
    }

    #[test]
    fn test_ema_new() {
        assert!(Ema::new(14).is_ok());
        assert!(Ema::new(0).is_err());
    }

    #[test]
    fn test_ema_calculation() {
        let mut ema = Ema::new(3).unwrap();
        let data = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let result = ema.calculate(&data).unwrap();
        assert_close(&result, &reference_ema(&data, ALPHA_3));
    }

    #[test]
    fn test_ema_next() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.next(2.0).unwrap(), Some(2.0));
        let expected1 = 4.0 * ALPHA_3 + 2.0 * (1.0 - ALPHA_3);
        assert_eq!(ema.next(4.0).unwrap(), Some(expected1));
        let expected2 = 6.0 * ALPHA_3 + expected1 * (1.0 - ALPHA_3);
        assert_eq!(ema.next(6.0).unwrap(), Some(expected2));
    }

    #[test]
    fn test_ema_reset() {
        let mut ema = Ema::new(3).unwrap();
        ema.next(2.0).unwrap();
        ema.next(4.0).unwrap();
        ema.reset_state();
        assert_eq!(ema.next(6.0).unwrap(), Some(6.0));
    }

    #[test]
    fn test_ema_trait_reset_f64() {
        let mut ema = Ema::new(3).unwrap();
        ema.next(2.0).unwrap();
        ema.next(4.0).unwrap();
        // Fully qualified: `Ema` implements `Indicator` for two input types.
        <Ema as Indicator<f64, f64>>::reset(&mut ema);
        assert_eq!(ema.next(6.0).unwrap(), Some(6.0));
    }

    #[test]
    fn test_ema_with_initial_value() {
        let mut ema = Ema::new(3).unwrap();
        ema.with_initial_value(10.0);
        let expected = 20.0 * ALPHA_3 + 10.0 * (1.0 - ALPHA_3);
        assert_eq!(ema.next(20.0).unwrap(), Some(expected));
    }

    #[test]
    fn test_ema_calculation_with_candles() {
        let mut ema = Ema::new(3).unwrap();
        let closes: Vec<f64> = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let candles = candles_from_closes(&closes);
        let result = ema.calculate(&candles).unwrap();
        assert_close(&result, &reference_ema(&closes, ALPHA_3));
    }

    #[test]
    fn test_ema_next_with_candles() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.next(candle(1, 2.0)).unwrap(), Some(2.0));
        let expected1 = 4.0 * ALPHA_3 + 2.0 * (1.0 - ALPHA_3);
        assert_eq!(ema.next(candle(2, 4.0)).unwrap(), Some(expected1));
        let expected2 = 6.0 * ALPHA_3 + expected1 * (1.0 - ALPHA_3);
        assert_eq!(ema.next(candle(3, 6.0)).unwrap(), Some(expected2));
    }

    #[test]
    fn test_ema_reset_with_candles() {
        let mut ema = Ema::new(3).unwrap();
        ema.next(candle(1, 2.0)).unwrap();
        ema.next(candle(2, 4.0)).unwrap();
        ema.reset_state();
        assert_eq!(ema.next(candle(3, 6.0)).unwrap(), Some(6.0));
    }

    #[test]
    fn test_ema_trait_reset_candles() {
        let mut ema = Ema::new(3).unwrap();
        ema.next(candle(1, 2.0)).unwrap();
        ema.next(candle(2, 4.0)).unwrap();
        <Ema as Indicator<Candle, f64>>::reset(&mut ema);
        assert_eq!(ema.next(candle(3, 6.0)).unwrap(), Some(6.0));
    }

    #[test]
    fn test_ema_implementations_produce_same_results() {
        let mut ema_f64 = Ema::new(3).unwrap();
        let mut ema_candle = Ema::new(3).unwrap();
        let prices: Vec<f64> = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let candles = candles_from_closes(&prices);
        let result_f64 = ema_f64.calculate(&prices).unwrap();
        let result_candle = ema_candle.calculate(&candles).unwrap();
        assert_close(&result_f64, &result_candle);
    }

    #[test]
    fn test_ema_next_implementations_produce_same_results() {
        let mut ema_f64 = Ema::new(3).unwrap();
        let mut ema_candle = Ema::new(3).unwrap();
        for (ts, close) in (1u32..).zip([2.0_f64, 4.0, 6.0]) {
            assert_eq!(
                ema_f64.next(close).unwrap(),
                ema_candle.next(candle(ts, close)).unwrap()
            );
        }
    }
}