use crate::bar_indicators::indicator_value::IndicatorValue;
use crate::bar_indicators::ohlcv_field::OhlcvField;
/// Largest look-back period that the `Wma` constructors accept.
const WMA_MAX_PERIOD: usize = 256;

/// Weighted moving average with linearly increasing weights: within the
/// window the oldest sample has weight 1 and the newest has weight `period`.
///
/// Samples are taken from each bar according to `source` and stored in a
/// ring buffer; running sums let updates avoid a full pass over the window.
#[derive(Debug, Clone)]
pub struct Wma {
    /// Window length.
    period: usize,
    /// Most recent samples; grows to `period` entries during warm-up, then is
    /// overwritten in place as a ring.
    buf: Vec<f64>,
    /// Once `buf` is full: slot of the oldest sample, i.e. the next one to
    /// be overwritten.
    idx: usize,
    /// True once `period` samples have been collected.
    filled: bool,
    /// Normalising denominator (sum of all weights).
    weight_sum: f64,
    /// Running sum of weight * sample over the window.
    weighted_sum: f64,
    /// Running plain sum of the samples in the window.
    unweighted_sum: f64,
    /// Last value produced by the indicator.
    value: f64,
    /// Which bar field feeds the average.
    source: OhlcvField,
}
impl Wma {
    /// Creates a WMA over `period` bars that averages each bar's close price.
    ///
    /// A `period` of 0 is treated as 1.
    ///
    /// # Panics
    ///
    /// Panics if `period` exceeds `WMA_MAX_PERIOD`.
    pub fn new(period: usize) -> Self {
        Self::with_source(period, OhlcvField::Close)
    }

    /// Creates a WMA over `period` bars that averages the bar field selected
    /// by `source`.
    ///
    /// A `period` of 0 is treated as 1.
    ///
    /// # Panics
    ///
    /// Panics if `period` exceeds `WMA_MAX_PERIOD`.
    pub fn with_source(period: usize, source: OhlcvField) -> Self {
        assert!(
            period <= WMA_MAX_PERIOD,
            "period {} exceeds WMA_MAX_PERIOD {}",
            period,
            WMA_MAX_PERIOD
        );
        let period = period.max(1);
        // Triangular number 1 + 2 + ... + period: the sum of all weights.
        let weight_sum = (period * (period + 1)) as f64 / 2.0;
        Self {
            period,
            buf: Vec::with_capacity(period),
            idx: 0,
            filled: false,
            weight_sum,
            weighted_sum: 0.0,
            unweighted_sum: 0.0,
            value: 0.0,
            source,
        }
    }

    /// Effective window length (the constructor argument, raised to at least 1).
    pub fn period(&self) -> usize {
        self.period
    }

    /// Latest output of [`update_bar`](Self::update_bar) as a single value.
    pub fn value(&self) -> IndicatorValue {
        IndicatorValue::Single(self.value)
    }

    /// Whether `period` samples have been seen since construction or the
    /// last [`reset`](Self::reset).
    pub fn is_ready(&self) -> bool {
        self.filled
    }

    /// Discards all samples and running state; `period` and `source` are kept.
    pub fn reset(&mut self) {
        self.buf.clear();
        self.idx = 0;
        self.filled = false;
        self.weighted_sum = 0.0;
        self.unweighted_sum = 0.0;
        self.value = 0.0;
    }

    /// Feeds one bar and returns the updated average.
    ///
    /// The sample is extracted from the bar using the configured source
    /// field. Inside the window the newest sample has weight `period` and the
    /// oldest weight 1. Until `period` samples have been seen (see
    /// [`is_ready`](Self::is_ready)) the raw sample itself is returned.
    ///
    /// With a full window the update is done on running sums in O(1). Those
    /// sums are rebuilt from the buffer once every `period` bars (amortised
    /// O(1)) and whenever they become non-finite. This bounds accumulated
    /// rounding/cancellation error to at most one cycle, and it lets the
    /// output recover as soon as a NaN or infinite sample has left the window.
    pub fn update_bar(&mut self, open: f64, high: f64, low: f64, close: f64, volume: f64) -> f64 {
        let value = self.source.extract(open, high, low, close, volume);
        if self.buf.len() < self.period {
            // Warm-up: collect samples in arrival order.
            self.buf.push(value);
            self.unweighted_sum += value;
            if self.buf.len() == self.period {
                self.filled = true;
                self.recompute_sums();
                self.value = self.weighted_sum / self.weight_sum;
            } else {
                self.value = value;
            }
        } else {
            let oldest = self.buf[self.idx];
            // Sliding the window lowers every remaining weight by one (subtract
            // the plain sum, which also removes `oldest` at weight 1) and adds
            // the new sample at the top weight `period`.
            self.weighted_sum = self.weighted_sum - self.unweighted_sum + (self.period as f64) * value;
            self.unweighted_sum = self.unweighted_sum - oldest + value;
            self.buf[self.idx] = value;
            self.idx = (self.idx + 1) % self.period;
            // Rebuild the sums on every full cycle, and also while they are
            // NaN or infinite, so errors cannot persist indefinitely.
            let sums_finite = self.weighted_sum.is_finite() && self.unweighted_sum.is_finite();
            if self.idx == 0 || !sums_finite {
                self.recompute_sums();
            }
            self.value = self.weighted_sum / self.weight_sum;
        }
        self.value
    }

    /// Recomputes `weighted_sum` and `unweighted_sum` from scratch over the
    /// whole buffer, in chronological order. The oldest sample (stored at
    /// `idx`) gets weight 1 and the newest gets weight `buf.len()`.
    fn recompute_sums(&mut self) {
        // `buf[idx..]` holds the older samples, `buf[..idx]` the newer ones.
        let (newer, older) = self.buf.split_at(self.idx);
        let (weighted, unweighted) = older
            .iter()
            .chain(newer)
            .enumerate()
            .fold((0.0, 0.0), |(w, u), (i, &x)| (w + x * (i + 1) as f64, u + x));
        self.weighted_sum = weighted;
        self.unweighted_sum = unweighted;
    }
}
/// Unit tests for `Wma`; expected values are written as exact fractions
/// (sum of weight * sample over the triangular weight sum).
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds a bar whose only non-zero field is the close price.
    fn feed_close(wma: &mut Wma, close: f64) -> f64 {
        wma.update_bar(0.0, 0.0, 0.0, close, 0.0)
    }

    #[test]
    fn test_wma_correctness() {
        let mut wma = Wma::new(5);
        for &x in &[10.0, 20.0, 30.0, 40.0] {
            feed_close(&mut wma, x);
        }
        // (1*10 + 2*20 + 3*30 + 4*40 + 5*50) / 15
        let result1 = feed_close(&mut wma, 50.0);
        assert!((result1 - 550.0 / 15.0).abs() < 1e-9, "Expected ~36.67, got {}", result1);
        // (1*20 + 2*30 + 3*40 + 4*50 + 5*60) / 15
        let result2 = feed_close(&mut wma, 60.0);
        assert!((result2 - 700.0 / 15.0).abs() < 1e-9, "Expected ~46.67, got {}", result2);
    }

    #[test]
    fn test_wma_warmup_echoes_input() {
        let mut wma = Wma::new(3);
        assert_eq!(feed_close(&mut wma, 10.0), 10.0);
        assert!(!wma.is_ready());
        assert_eq!(feed_close(&mut wma, 20.0), 20.0);
        assert!(!wma.is_ready());
        // (1*10 + 2*20 + 3*30) / 6
        let result = feed_close(&mut wma, 30.0);
        assert!((result - 140.0 / 6.0).abs() < 1e-9, "got {}", result);
        assert!(wma.is_ready());
    }

    #[test]
    fn test_wma_period_1() {
        let mut wma = Wma::new(1);
        assert_eq!(feed_close(&mut wma, 10.0), 10.0);
        assert_eq!(feed_close(&mut wma, 20.0), 20.0);
        assert_eq!(feed_close(&mut wma, 30.0), 30.0);
    }

    #[test]
    fn test_wma_period_zero_is_clamped() {
        assert_eq!(Wma::new(0).period(), 1);
    }

    #[test]
    #[should_panic(expected = "exceeds WMA_MAX_PERIOD")]
    fn test_wma_rejects_period_above_max() {
        Wma::new(WMA_MAX_PERIOD + 1);
    }

    #[test]
    fn test_wma_reset() {
        let mut wma = Wma::new(3);
        for &x in &[10.0, 20.0, 30.0, 40.0] {
            feed_close(&mut wma, x);
        }
        assert!(wma.is_ready());
        wma.reset();
        assert!(!wma.is_ready());
        // After a reset the indicator must behave exactly like a new one,
        // i.e. no samples or running sums survive.
        let mut fresh = Wma::new(3);
        for &x in &[1.0, 2.0, 3.0] {
            assert_eq!(feed_close(&mut wma, x), feed_close(&mut fresh, x));
        }
        assert!(wma.is_ready());
    }

    #[test]
    fn test_wma_with_source_hl2() {
        let mut wma = Wma::with_source(3, OhlcvField::HL2);
        wma.update_bar(0.0, 110.0, 90.0, 105.0, 0.0);
        wma.update_bar(0.0, 120.0, 80.0, 110.0, 0.0);
        let result = wma.update_bar(0.0, 130.0, 70.0, 115.0, 0.0);
        assert!((result - 100.0).abs() < 1e-10);
    }

    #[test]
    fn test_wma_with_source_high() {
        let mut wma = Wma::with_source(3, OhlcvField::High);
        wma.update_bar(100.0, 110.0, 90.0, 105.0, 0.0);
        wma.update_bar(105.0, 120.0, 95.0, 110.0, 0.0);
        // (1*110 + 2*120 + 3*130) / 6
        let result = wma.update_bar(110.0, 130.0, 100.0, 115.0, 0.0);
        assert!((result - 740.0 / 6.0).abs() < 1e-9, "got {}", result);
    }
}