use std::collections::HashMap;
use crate::error::IndicatorError;
use crate::functions::{self};
use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
use crate::registry::{param_f64, param_str, param_usize};
use crate::types::Candle;
/// Configuration for the [`Ema`] indicator.
#[derive(Debug, Clone)]
pub struct EmaParams {
    /// Look-back window length.
    pub period: usize,
    /// Optional explicit smoothing factor; when `None` the period-derived
    /// factor `2 / (period + 1)` is used instead.
    pub alpha: Option<f64>,
    /// Candle price series the average is computed over.
    pub column: PriceColumn,
}
impl Default for EmaParams {
    /// A 20-period EMA over closing prices, using the period-derived smoothing factor.
    fn default() -> Self {
        let column = PriceColumn::Close;
        Self {
            column,
            alpha: None,
            period: 20,
        }
    }
}
impl EmaParams {
    /// Smoothing factor to apply: the explicit `alpha` when one was given,
    /// otherwise the conventional `2 / (period + 1)`.
    fn effective_alpha(&self) -> f64 {
        match self.alpha {
            Some(explicit) => explicit,
            None => 2.0 / (self.period as f64 + 1.0),
        }
    }
}
/// Exponential moving average indicator; its output series is keyed `EMA_<period>`.
#[derive(Debug, Clone)]
pub struct Ema {
    /// Period, smoothing factor and source column for this indicator.
    pub params: EmaParams,
}
impl Ema {
pub fn new(params: EmaParams) -> Self {
Self { params }
}
pub fn with_period(period: usize) -> Self {
Self::new(EmaParams {
period,
..Default::default()
})
}
fn output_key(&self) -> String {
format!("EMA_{}", self.params.period)
}
}
impl Indicator for Ema {
fn name(&self) -> &'static str {
"EMA"
}
fn required_len(&self) -> usize {
self.params.period
}
fn required_columns(&self) -> &[&'static str] {
&["close"]
}
fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
self.check_len(candles)?;
let prices = self.params.column.extract(candles);
let _alpha = self.params.effective_alpha();
let _n = prices.len();
let period = self.params.period;
let values = functions::ema(&prices, period)?;
Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
}
}
/// Registry constructor: builds an [`Ema`] from string parameters.
///
/// Recognised keys:
/// - `period` (default 20).
/// - `alpha`: only applied when present.
/// - `column`: `open`/`high`/`low`; any other value, or no value, selects close.
///
/// # Errors
/// Propagates parse errors from `param_usize` / `param_f64`.
pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
    let period = param_usize(params, "period", 20)?;

    // Only parse alpha when the caller supplied it; absence stays `None` so the
    // indicator falls back to its period-derived smoothing factor.
    let alpha = params
        .contains_key("alpha")
        .then(|| param_f64(params, "alpha", 2.0 / (period as f64 + 1.0)))
        .transpose()?;

    let column = match param_str(params, "column", "close") {
        "open" => PriceColumn::Open,
        "high" => PriceColumn::High,
        "low" => PriceColumn::Low,
        _ => PriceColumn::Close,
    };

    let ema = Ema::new(EmaParams {
        period,
        alpha,
        column,
    });
    Ok(Box::new(ema))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds flat candles (open = high = low = close) from closing prices.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn ema_insufficient_data() {
        let ema = Ema::with_period(5);
        assert!(ema.calculate(&candles(&[1.0, 2.0])).is_err());
    }

    #[test]
    fn ema_output_column_named_correctly() {
        let ema = Ema::with_period(3);
        let out = ema.calculate(&candles(&[10.0, 20.0, 30.0])).unwrap();
        assert!(out.get("EMA_3").is_some());
    }

    #[test]
    fn ema_seed_equals_sma() {
        let closes = vec![10.0, 20.0, 30.0];
        let ema = Ema::with_period(3);
        let out = ema.calculate(&candles(&closes)).unwrap();
        let vals = out.get("EMA_3").unwrap();
        let expected_seed = (10.0 + 20.0 + 30.0) / 3.0;
        assert!((vals[2] - expected_seed).abs() < 1e-9, "got {}", vals[2]);
    }

    #[test]
    fn ema_subsequent_value() {
        let closes = vec![10.0, 20.0, 30.0, 40.0];
        let ema = Ema::with_period(3);
        let out = ema.calculate(&candles(&closes)).unwrap();
        let vals = out.get("EMA_3").unwrap();
        // alpha = 2 / (3 + 1) = 0.5; seed = SMA(10, 20, 30) = 20.
        let expected = 40.0 * 0.5 + 20.0 * 0.5;
        assert!((vals[3] - expected).abs() < 1e-6, "got {}", vals[3]);
    }

    #[test]
    fn factory_creates_ema() {
        let params: HashMap<String, String> =
            HashMap::from([("period".to_string(), "12".to_string())]);
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "EMA");
        assert_eq!(ind.required_len(), 12);
    }

    #[test]
    fn factory_defaults_period_to_20() {
        let params: HashMap<String, String> = HashMap::new();
        let ind = factory(&params).unwrap();
        assert_eq!(ind.required_len(), 20);
    }
}