// finlib_ta/indicators/weighted_moving_average.rs
use alloc::boxed::Box;
use alloc::vec;
use core::fmt;

use crate::errors::{Result, TaError};
use crate::{Close, Next, Period, Reset};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/// Weighted moving average (WMA).
///
/// A moving average over the last `period` inputs in which the weights grow
/// linearly towards the present: the newest value gets the largest weight and
/// the oldest value in the window gets weight `1`.
///
/// # Formula
///
/// With `x_1` the oldest and `x_n` the newest of the `n` values in the window:
///
/// ```text
/// WMA = (1·x_1 + 2·x_2 + … + n·x_n) / (1 + 2 + … + n)
///     = Σ i·x_i / (n·(n + 1) / 2)
/// ```
///
/// Once `period` values have been seen, `n == period`. Before that (warm-up),
/// `n` is the number of values seen so far, so the first output equals the
/// first input.
///
/// Each update is O(1): a running weighted sum and a running plain sum are
/// maintained instead of re-weighting the whole window.
///
/// # Parameters
///
/// * `period` — number of values in the window; must be greater than 0.
///
/// # Example
///
/// ```text
/// period = 3, inputs 12, 3, 3, 5  ->  outputs 12, 6, 4.5, 4
/// ```
#[doc(alias = "WMA")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct WeightedMovingAverage {
    /// Window length; `new` rejects 0.
    period: usize,
    /// Ring-buffer slot that the next input overwrites.
    index: usize,
    /// Number of inputs seen so far, capped at `period`.
    count: usize,
    /// `count` as `f64`: the weight given to the newest input.
    weight: f64,
    /// Running weighted sum `Σ i·x_i` over the current window.
    sum: f64,
    /// Running plain sum `Σ x_i` over the current window.
    sum_flat: f64,
    /// Ring buffer of the last `period` inputs; slots not yet written hold 0.0.
    deque: Box<[f64]>,
}

56impl WeightedMovingAverage {
57 pub fn new(period: usize) -> Result<Self> {
58 match period {
59 0 => Err(TaError::InvalidParameter),
60 _ => Ok(Self {
61 period,
62 index: 0,
63 count: 0,
64 weight: 0.0,
65 sum: 0.0,
66 sum_flat: 0.0,
67 deque: vec![0.0; period].into_boxed_slice(),
68 }),
69 }
70 }
71}
72
73impl Period for WeightedMovingAverage {
74 fn period(&self) -> usize {
75 self.period
76 }
77}
78
79impl Next<f64> for WeightedMovingAverage {
80 type Output = f64;
81
82 fn next(&mut self, input: f64) -> Self::Output {
83 let old_val: f64 = self.deque[self.index];
84 self.deque[self.index] = input;
85
86 self.index = if self.index + 1 < self.period {
87 self.index + 1
88 } else {
89 0
90 };
91
92 if self.count < self.period {
93 self.count += 1;
94 self.weight = self.count as f64;
95 self.sum += input * self.weight
96 } else {
97 self.sum = self.sum - self.sum_flat + (input * self.weight);
98 }
99 self.sum_flat = self.sum_flat - old_val + input;
100 self.sum / (self.weight * (self.weight + 1.0) / 2.0)
101 }
102}
103
104impl<T: Close> Next<&T> for WeightedMovingAverage {
105 type Output = f64;
106
107 fn next(&mut self, input: &T) -> Self::Output {
108 self.next(input.close())
109 }
110}
111
112impl Reset for WeightedMovingAverage {
113 fn reset(&mut self) {
114 self.index = 0;
115 self.count = 0;
116 self.weight = 0.0;
117 self.sum = 0.0;
118 self.sum_flat = 0.0;
119 for i in 0..self.period {
120 self.deque[i] = 0.0;
121 }
122 }
123}
124
125impl Default for WeightedMovingAverage {
126 fn default() -> Self {
127 Self::new(9).unwrap()
128 }
129}
130
131impl fmt::Display for WeightedMovingAverage {
132 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
133 write!(f, "WMA({})", self.period)
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helper::*;
    use alloc::format;

    test_indicator!(WeightedMovingAverage);

    /// Direct O(period) WMA of the last `period` values of `data`, used as a
    /// reference for the incremental implementation.
    fn naive_wma(data: &[f64], period: usize) -> f64 {
        let window = &data[data.len().saturating_sub(period)..];
        let numerator: f64 = window
            .iter()
            .enumerate()
            .map(|(i, x)| (i + 1) as f64 * x)
            .sum();
        let n = window.len() as f64;
        numerator / (n * (n + 1.0) / 2.0)
    }

    #[test]
    fn test_new() {
        assert!(WeightedMovingAverage::new(0).is_err());
        assert!(WeightedMovingAverage::new(1).is_ok());
    }

    #[test]
    fn test_period() {
        let wma = WeightedMovingAverage::new(4).unwrap();
        assert_eq!(wma.period(), 4);
    }

    #[test]
    fn test_next() {
        let mut wma = WeightedMovingAverage::new(3).unwrap();

        assert_eq!(wma.next(12.0), 12.0);
        assert_eq!(wma.next(3.0), 6.0); // (1*12 + 2*3) / 3
        assert_eq!(wma.next(3.0), 4.5); // (1*12 + 2*3 + 3*3) / 6
        assert_eq!(wma.next(5.0), 4.0); // (1*3 + 2*3 + 3*5) / 6

        let mut wma = WeightedMovingAverage::new(3).unwrap();
        let bar1 = Bar::new().close(2);
        let bar2 = Bar::new().close(5);
        assert_eq!(wma.next(&bar1), 2.0);
        assert_eq!(wma.next(&bar2), 4.0); // (1*2 + 2*5) / 3
    }

    #[test]
    fn test_next_matches_direct_formula() {
        // Small dyadic rationals keep every intermediate sum exact in f64, so the
        // incremental result can be compared exactly with the direct formula,
        // including after the ring buffer has wrapped several times.
        let data = [12.0, 3.0, 3.0, 5.0, 8.0, -2.0, 7.5, 0.0, 4.25, 9.0, 1.0, 6.0];
        for &period in [1usize, 2, 3, 5].iter() {
            let mut wma = WeightedMovingAverage::new(period).unwrap();
            for end in 1..=data.len() {
                assert_eq!(wma.next(data[end - 1]), naive_wma(&data[..end], period));
            }
        }
    }

    #[test]
    fn test_reset() {
        let mut wma = WeightedMovingAverage::new(5).unwrap();

        assert_eq!(wma.next(4.0), 4.0);
        wma.next(10.0);
        wma.next(15.0);
        wma.next(20.0);
        assert_ne!(wma.next(4.0), 4.0);

        wma.reset();
        assert_eq!(wma.next(4.0), 4.0);

        // After a reset the indicator must behave exactly like a fresh one,
        // including the warm-up weights and the later wrap-around.
        let mut fresh = WeightedMovingAverage::new(5).unwrap();
        fresh.next(4.0);
        for &x in [10.0, 15.0, 20.0, 4.0, 7.0, -3.0].iter() {
            assert_eq!(wma.next(x), fresh.next(x));
        }
    }

    #[test]
    fn test_default() {
        WeightedMovingAverage::default();
    }

    #[test]
    fn test_display() {
        let wma = WeightedMovingAverage::new(7).unwrap();
        assert_eq!(format!("{}", wma), "WMA(7)");
    }
}