use crate::{KandError, TAFloat};
pub const fn lookback(param_period: usize) -> Result<usize, KandError> {
#[cfg(feature = "check")]
{
if param_period < 2 {
return Err(KandError::InvalidParameter);
}
}
Ok(param_period - 1)
}
/// Computes the Weighted Moving Average (WMA) of `input` into `output`.
///
/// For every index `i >= param_period - 1`, the value is
/// `sum_{k=0}^{param_period-1} input[i - k] * (param_period - k)` divided by
/// `param_period * (param_period + 1) / 2`. The newest bar in the window gets
/// weight `param_period` and the oldest gets weight `1`. The first
/// `param_period - 1` entries of `output` are set to `NaN`.
///
/// # Arguments
/// * `input` - source series.
/// * `param_period` - window length (validated by [`lookback`]).
/// * `output` - destination buffer, expected to be as long as `input`.
///
/// # Errors
/// * `KandError::InvalidParameter` - `param_period < 2` (with `check`).
/// * `KandError::InvalidData` - `input` is empty (with `check`).
/// * `KandError::LengthMismatch` - `input` and `output` lengths differ (with `check`).
/// * `KandError::InsufficientData` - `input.len() <= param_period - 1` (with `check`).
/// * `KandError::NaNDetected` - `input` contains a NaN (with `deep-check`).
///
/// # Panics
/// Without the `check` feature, an `output` shorter than `input` can cause an
/// out-of-bounds index panic.
pub fn wma(
    input: &[TAFloat],
    param_period: usize,
    output: &mut [TAFloat],
) -> Result<(), KandError> {
    let n = input.len();
    let warmup = lookback(param_period)?;

    #[cfg(feature = "check")]
    {
        if n == 0 {
            return Err(KandError::InvalidData);
        }
        if output.len() != n {
            return Err(KandError::LengthMismatch);
        }
        if n <= warmup {
            return Err(KandError::InsufficientData);
        }
    }

    #[cfg(feature = "deep-check")]
    {
        if input.iter().any(|x| x.is_nan()) {
            return Err(KandError::NaNDetected);
        }
    }

    // Sum of the weights 1 + 2 + ... + param_period.
    let weight_total = (param_period * (param_period + 1)) as TAFloat / 2.0;

    // Bars before a full window is available have no defined value.
    output
        .iter_mut()
        .take(warmup)
        .for_each(|slot| *slot = TAFloat::NAN);

    for idx in warmup..n {
        // `window` holds exactly `param_period` bars ending at `idx`. Walking it
        // in reverse pairs the newest bar with the largest weight, which keeps
        // the same accumulation order as a newest-to-oldest loop.
        let window = &input[idx - warmup..=idx];
        let init: (TAFloat, TAFloat) = (0.0, param_period as TAFloat);
        let (numerator, _) = window
            .iter()
            .rev()
            .fold(init, |(acc, w), &x| (acc + x * w, w - 1.0));
        output[idx] = numerator / weight_total;
    }

    Ok(())
}
/// Computes a single WMA value from a window of values ordered newest first.
///
/// `input_window[0]` receives weight `param_period`, the next element
/// `param_period - 1`, and so on. The weighted sum is divided by
/// `param_period * (param_period + 1) / 2`. Given the last `param_period` bars
/// in reverse order, this matches the corresponding value produced by [`wma`].
///
/// # Errors
/// * `KandError::InvalidParameter` - `param_period < 2` (with `check`).
/// * `KandError::LengthMismatch` - `input_window.len() != param_period` (with `check`).
/// * `KandError::NaNDetected` - the window contains a NaN (with `deep-check`).
///
/// Without `check`, a window whose length differs from `param_period` is not
/// rejected. Its weights keep decreasing by one per element, so the result is
/// not a proper WMA.
pub fn wma_inc(input_window: &[TAFloat], param_period: usize) -> Result<TAFloat, KandError> {
    #[cfg(feature = "check")]
    {
        if param_period < 2 {
            return Err(KandError::InvalidParameter);
        }
        if input_window.len() != param_period {
            return Err(KandError::LengthMismatch);
        }
    }

    #[cfg(feature = "deep-check")]
    {
        if input_window.iter().any(|x| x.is_nan()) {
            return Err(KandError::NaNDetected);
        }
    }

    // Sum of the weights 1 + 2 + ... + param_period.
    let divisor = (param_period * (param_period + 1)) as TAFloat / 2.0;

    // Carry the current weight alongside the running sum, starting at the
    // largest weight for the newest value.
    let init: (TAFloat, TAFloat) = (0.0, param_period as TAFloat);
    let (numerator, _) = input_window
        .iter()
        .fold(init, |(sum, w), &x| (sum + x * w, w - 1.0));

    Ok(numerator / divisor)
}
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// Checks `wma` against reference values, then checks that `wma_inc`
    /// agrees with the batch result at every index that has a full window.
    #[test]
    fn test_wma_calculation() {
        let input: &[TAFloat] = &[
            35216.1, 35221.4, 35190.7, 35170.0, 35181.5, 35254.6, 35202.8, 35251.9, 35197.6,
            35184.7, 35175.1, 35229.9, 35212.5, 35160.7, 35090.3, 35041.2, 34999.3, 35013.4,
            35069.0, 35024.6, 34939.5, 34952.6, 35000.0, 35041.8, 35080.0, 35114.5, 35097.2,
            35092.0, 35073.2, 35139.3, 35092.0, 35126.7, 35106.3, 35124.8, 35170.1, 35215.3,
            35154.0, 35216.3, 35211.8, 35158.4,
        ];
        let param_period = 30;
        // Derive the warm-up length rather than hard-coding it, so the test
        // tracks `lookback` if the period changes.
        let warmup = lookback(param_period).unwrap();

        let mut output: Vec<TAFloat> = vec![0.0; input.len()];
        wma(input, param_period, &mut output).unwrap();

        // The warm-up region has no defined value.
        assert!(output[..warmup].iter().all(|v| v.is_nan()));

        let expected_values = [
            35_086.706_666_666_67,
            35_084.862_795_698_93,
            35_085.524_516_129_04,
            35_085.073_763_440_865,
            35_085.998_064_516_134,
            35_089.942_150_537_645,
            35_096.826_881_720_44,
            35_099.841_290_322_58,
            35_106.98,
            35_113.904_946_236_566,
            35_117.354_193_548_395,
        ];
        // The reference list must cover exactly the indices with a full window.
        assert_eq!(expected_values.len(), input.len() - warmup);
        for (i, expected) in expected_values.iter().enumerate() {
            assert_relative_eq!(output[warmup + i], *expected, epsilon = 0.0001);
        }

        // The incremental form takes the window newest-first and must match the
        // batch result at every computed index, including the first and last.
        for i in warmup..input.len() {
            let window: Vec<TAFloat> = input[i - warmup..=i].iter().rev().copied().collect();
            let result = wma_inc(&window, param_period).unwrap();
            assert_relative_eq!(result, output[i], epsilon = 0.0001);
        }
    }
}