use crate::{KandError, TAFloat, helper::period_to_k};
/// Returns how many leading positions [`ema`] fills with NaN before its first value.
///
/// The first EMA value is the simple average of the first `param_period` prices, so it
/// lands at index `param_period - 1`; every earlier index is part of the warm-up window.
///
/// # Errors
/// With the `check` feature enabled, returns [`KandError::InvalidParameter`] when
/// `param_period` is less than 2. Without that feature no validation is performed.
pub const fn lookback(param_period: usize) -> Result<usize, KandError> {
    #[cfg(feature = "check")]
    {
        if param_period <= 1 {
            return Err(KandError::InvalidParameter);
        }
    }
    let warmup_len = param_period - 1;
    Ok(warmup_len)
}
/// Computes the Exponential Moving Average (EMA) of `input_prices` into `output_ema`.
///
/// The series is seeded at index `param_period - 1` with the simple average of the first
/// `param_period` prices. Every later value follows the recurrence
/// `ema = (price - prev_ema) * k + prev_ema` (evaluated with a fused multiply-add).
/// The `lookback(param_period)` leading outputs are set to NaN.
///
/// # Arguments
/// * `input_prices` - Price series to smooth.
/// * `param_period` - Smoothing period; must be at least 2 when the `check` feature is on.
/// * `param_k` - Optional smoothing factor; when `None` it is `period_to_k(param_period)`.
/// * `output_ema` - Destination buffer; must be as long as `input_prices`.
///
/// # Errors
/// * `KandError::InvalidParameter` if `param_period < 2` (`check` feature).
/// * `KandError::InvalidData` if `input_prices` is empty (`check` feature).
/// * `KandError::InsufficientData` if there are fewer than `param_period` prices (`check`).
/// * `KandError::LengthMismatch` if `output_ema.len() != input_prices.len()` (`check`).
/// * `KandError::NaNDetected` if any input price is NaN (`deep-check` feature).
/// * Any error returned by `period_to_k` when `param_k` is `None`.
///
/// When an error is returned, `output_ema` has not been modified.
///
/// # Panics
/// Without the `check` feature the size preconditions above are not validated, and
/// violating them may cause a slice-index panic.
pub fn ema(
    input_prices: &[TAFloat],
    param_period: usize,
    param_k: Option<TAFloat>,
    output_ema: &mut [TAFloat],
) -> Result<(), KandError> {
    let len = input_prices.len();
    let lookback = lookback(param_period)?;

    #[cfg(feature = "check")]
    {
        if len == 0 {
            return Err(KandError::InvalidData);
        }
        if len <= lookback {
            return Err(KandError::InsufficientData);
        }
        if output_ema.len() != len {
            return Err(KandError::LengthMismatch);
        }
    }

    #[cfg(feature = "deep-check")]
    {
        if input_prices.iter().any(|price| price.is_nan()) {
            return Err(KandError::NaNDetected);
        }
    }

    // Resolve the fallible smoothing factor before touching `output_ema`, so an error
    // never leaves the caller's buffer partially written.
    let multiplier = match param_k {
        Some(k) => k,
        None => period_to_k(param_period)?,
    };

    // Seed value: simple average of the first `param_period` prices.
    let seed_prices = &input_prices[..param_period];
    let mut prev_ema = seed_prices.iter().sum::<TAFloat>() / (param_period as TAFloat);

    // Only the first `len` output slots are written. Split them into the warm-up window
    // (indices 0..=lookback, ending with the seed) and the recursive tail (param_period..len).
    let (warmup, tail) = output_ema[..len].split_at_mut(param_period);
    warmup[..lookback].fill(TAFloat::NAN);
    warmup[lookback] = prev_ema;

    // Zipping output and input slices avoids per-element bounds checks in the hot loop.
    for (out, &price) in tail.iter_mut().zip(&input_prices[param_period..]) {
        prev_ema = (price - prev_ema).mul_add(multiplier, prev_ema);
        *out = prev_ema;
    }

    Ok(())
}
/// Advances an EMA by a single price without recomputing the whole series.
///
/// Returns `(input_price - prev_ema) * k + prev_ema`, evaluated with a fused multiply-add.
/// `k` is `param_k` when supplied, otherwise `period_to_k(param_period)`.
///
/// # Arguments
/// * `input_price` - Newest price.
/// * `prev_ema` - EMA value for the previous price.
/// * `param_period` - Smoothing period; must be at least 2 when the `check` feature is on.
/// * `param_k` - Optional explicit smoothing factor.
///
/// # Errors
/// * `KandError::InvalidParameter` if `param_period < 2` (`check` feature only).
/// * `KandError::NaNDetected` if `input_price` or `prev_ema` is NaN (`deep-check` only).
/// * Any error returned by `period_to_k` when `param_k` is `None`.
pub fn ema_inc(
    input_price: TAFloat,
    prev_ema: TAFloat,
    param_period: usize,
    param_k: Option<TAFloat>,
) -> Result<TAFloat, KandError> {
    #[cfg(feature = "check")]
    {
        if param_period <= 1 {
            return Err(KandError::InvalidParameter);
        }
    }

    #[cfg(feature = "deep-check")]
    {
        if [input_price, prev_ema].iter().any(|value| value.is_nan()) {
            return Err(KandError::NaNDetected);
        }
    }

    // `period_to_k` runs only when no explicit smoothing factor was supplied.
    let smoothing = param_k.map_or_else(|| period_to_k(param_period), Ok)?;
    let deviation = input_price - prev_ema;
    Ok(deviation.mul_add(smoothing, prev_ema))
}
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// Checks batch `ema` against reference values, then checks that `ema_inc`, seeded from
    /// the batch output, reproduces every later batch value.
    #[test]
    fn test_ema_calculation() {
        let input_prices = vec![
            35216.1, 35221.4, 35190.7, 35170.0, 35181.5, 35254.6, 35202.8, 35251.9, 35197.6,
            35184.7, 35175.1, 35229.9, 35212.5, 35160.7, 35090.3, 35041.2, 34999.3, 35013.4,
            35069.0, 35024.6, 34939.5, 34952.6, 35000.0, 35041.8, 35080.0, 35114.5, 35097.2,
            35092.0, 35073.2, 35139.3, 35092.0, 35126.7, 35106.3, 35124.8, 35170.1, 35215.3,
        ];
        let param_period = 14;
        let warmup = lookback(param_period).unwrap();
        let mut output_ema = vec![0.0; input_prices.len()];
        ema(&input_prices, param_period, None, &mut output_ema).unwrap();

        // Every index before the first full window must be NaN.
        assert!(output_ema[..warmup].iter().all(|value| value.is_nan()));

        let expected_values = [
            35_203.535_714_285_72,
            35_188.437_619_047_625,
            35_168.805_936_507_94,
            35_146.205_144_973_545,
            35_128.497_792_310_41,
            35_120.564_753_335_69,
            35_107.769_452_890_934,
            35_085.333_525_838_81,
            35_067.635_722_393_636,
            35_058.617_626_074_48,
            35_056.375_275_931_22,
            35_059.525_239_140_39,
            35_066.855_207_255_01,
        ];
        for (actual, expected) in output_ema[warmup..].iter().zip(expected_values.iter()) {
            assert_relative_eq!(*actual, *expected, epsilon = 0.00001);
        }

        // Reuse the batch result (no need to recompute it) and walk the whole tail.
        let mut prev_ema = output_ema[warmup];
        for (&price, &batch_value) in input_prices[param_period..]
            .iter()
            .zip(&output_ema[param_period..])
        {
            let result = ema_inc(price, prev_ema, param_period, None).unwrap();
            assert_relative_eq!(result, batch_value, epsilon = 0.00001);
            prev_ema = result;
        }
    }

    /// An explicit smoothing factor is used verbatim; k = 0.5 keeps every value exact.
    #[test]
    fn test_ema_with_explicit_k() {
        let input_prices = vec![1.0, 3.0, 5.0, 7.0];
        let mut output_ema = vec![0.0; input_prices.len()];
        ema(&input_prices, 2, Some(0.5), &mut output_ema).unwrap();
        assert!(output_ema[0].is_nan());
        assert_relative_eq!(output_ema[1], 2.0);
        assert_relative_eq!(output_ema[2], 3.5);
        assert_relative_eq!(output_ema[3], 5.25);
    }
}