use crate::{KandError, TAFloat};
/// Returns the lookback of the simple moving average for `param_period`,
/// which is `param_period - 1`.
///
/// # Errors
///
/// With the `check` feature enabled, returns [`KandError::InvalidParameter`]
/// when `param_period` is less than 2.
///
/// Without the `check` feature no validation is performed, so a
/// `param_period` of 0 makes the subtraction underflow.
pub const fn lookback(param_period: usize) -> Result<usize, KandError> {
    #[cfg(feature = "check")]
    {
        // Periods 0 and 1 are the only values below the minimum of 2.
        if matches!(param_period, 0 | 1) {
            return Err(KandError::InvalidParameter);
        }
    }

    let bars_before_first_value = param_period - 1;
    Ok(bars_before_first_value)
}
/// Computes the simple moving average of `input` over a window of
/// `param_period` values and writes it into `output_sma`.
///
/// The first `lookback(param_period)` slots of `output_sma` are set to NaN.
/// Every later slot `i` holds the mean of the `param_period` inputs ending
/// at index `i`, maintained as a running sum that adds the newest value and
/// subtracts the one that left the window.
///
/// # Arguments
///
/// * `input` - Source values.
/// * `param_period` - Window length.
/// * `output_sma` - Destination slice, written in place.
///
/// # Errors
///
/// * Any error returned by `lookback(param_period)` is propagated.
/// * With the `check` feature: [`KandError::InvalidData`] if `input` is
///   empty, [`KandError::LengthMismatch`] if `output_sma` and `input` differ
///   in length, [`KandError::InsufficientData`] if `input` has no more
///   elements than the lookback.
/// * With the `deep-check` feature: [`KandError::NaNDetected`] if any input
///   value is NaN.
///
/// # Panics
///
/// Without the `check` feature the slices are indexed directly, so an empty
/// `input` or an `output_sma` that is too short panics.
pub fn sma(
    input: &[TAFloat],
    param_period: usize,
    output_sma: &mut [TAFloat],
) -> Result<(), KandError> {
    let lookback = lookback(param_period)?;

    #[cfg(feature = "check")]
    {
        let len = input.len();
        if len == 0 {
            return Err(KandError::InvalidData);
        }
        if output_sma.len() != len {
            return Err(KandError::LengthMismatch);
        }
        if len <= lookback {
            return Err(KandError::InsufficientData);
        }
    }

    #[cfg(feature = "deep-check")]
    {
        if input.iter().any(|price| price.is_nan()) {
            return Err(KandError::NaNDetected);
        }
    }

    let divisor = param_period as TAFloat;

    // Seed the running sum with the first window, added left to right.
    let mut window_sum = input
        .iter()
        .skip(1)
        .take(lookback)
        .fold(input[0], |acc, &price| acc + price);
    output_sma[lookback] = window_sum / divisor;

    // Slide the window: add the incoming value, then drop the outgoing one.
    for (idx, &incoming) in input.iter().enumerate().skip(lookback + 1) {
        window_sum = window_sum + incoming - input[idx - param_period];
        output_sma[idx] = window_sum / divisor;
    }

    // Slots before the first full window carry no value.
    output_sma[..lookback].fill(TAFloat::NAN);

    Ok(())
}
/// Advances a simple moving average by one step without rescanning the
/// window.
///
/// The result is
/// `prev_sma + (input_new_price - input_old_price) / param_period`.
///
/// # Arguments
///
/// * `prev_sma` - Average before the update.
/// * `input_new_price` - Value added in this step.
/// * `input_old_price` - Value removed in this step.
/// * `param_period` - Window length used as the divisor.
///
/// # Errors
///
/// * With the `check` feature: [`KandError::InvalidParameter`] if
///   `param_period` is less than 2.
/// * With the `deep-check` feature: [`KandError::NaNDetected`] if any of the
///   three floating-point arguments is NaN.
pub fn sma_inc(
    prev_sma: TAFloat,
    input_new_price: TAFloat,
    input_old_price: TAFloat,
    param_period: usize,
) -> Result<TAFloat, KandError> {
    #[cfg(feature = "check")]
    {
        if param_period < 2 {
            return Err(KandError::InvalidParameter);
        }
    }

    #[cfg(feature = "deep-check")]
    {
        let operands = [prev_sma, input_new_price, input_old_price];
        if operands.iter().any(|value| value.is_nan()) {
            return Err(KandError::NaNDetected);
        }
    }

    // Net change of the window sum, spread over the period.
    let delta = input_new_price - input_old_price;
    Ok(prev_sma + delta / param_period as TAFloat)
}
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// Checks the batch SMA against reference values for a 14-period window
    /// and confirms that the incremental update reproduces the batch output.
    #[test]
    fn test_sma_calculation() {
        let param_period = 14;
        let input = vec![
            35216.1, 35221.4, 35190.7, 35170.0, 35181.5, 35254.6, 35202.8, 35251.9, 35197.6,
            35184.7, 35175.1, 35229.9, 35212.5, 35160.7, 35090.3, 35041.2, 34999.3, 35013.4,
            35069.0, 35024.6, 34939.5, 34952.6, 35000.0, 35041.8, 35080.0, 35114.5, 35097.2,
            35092.0, 35073.2, 35139.3, 35092.0, 35126.7, 35106.3, 35124.8, 35170.1, 35215.3,
            35154.0, 35216.3, 35211.8,
        ];

        let mut output_sma = vec![0.0; input.len()];
        sma(&input, param_period, &mut output_sma).unwrap();

        // Index of the first slot that holds a real average (13 here).
        let first_valid = param_period - 1;

        // Everything before it must be NaN.
        assert!(output_sma[..first_valid].iter().all(|v| v.is_nan()));

        // Reference averages starting at the first valid slot.
        let expected_values = [
            35_203.535_714_285_72,
            35194.55,
            35_181.678_571_428_57,
            35_168.007_142_857_15,
            35_156.821_428_571_435,
            35_148.785_714_285_72,
            35_132.357_142_857_145,
            35113.55,
            35_092.171_428_571_43,
            35_078.057_142_857_15,
            35067.85,
            35_061.057_142_857_15,
            35_052.814_285_714_29,
        ];
        for (actual, expected) in output_sma[first_valid..].iter().zip(expected_values.iter()) {
            assert_relative_eq!(*actual, *expected, epsilon = 0.00001);
        }

        // Roll the incremental form forward three steps from the first valid
        // average and compare each step with the batch result.
        let mut rolling = output_sma[first_valid];
        for step in (first_valid + 1)..(first_valid + 4) {
            let incoming = input[step];
            let outgoing = input[step - param_period];
            let updated = sma_inc(rolling, incoming, outgoing, param_period).unwrap();
            assert_relative_eq!(updated, output_sma[step], epsilon = 0.00001);
            rolling = updated;
        }
    }
}