finlib_ta/indicators/
simple_moving_average.rs1use core::fmt;
2
3use crate::errors::{Result, TaError};
4use crate::{Close, Next, Period, Reset};
5use alloc::boxed::Box;
6use alloc::vec;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9
/// Simple moving average (SMA).
///
/// Each output is the arithmetic mean of the most recent `period` inputs. Until
/// `period` values have been received, the mean is taken over the values received
/// so far, so the first output after construction (or after `reset`) equals the
/// first input.
///
/// Inputs are stored in a fixed-size ring buffer of length `period`.
///
/// Construct with [`SimpleMovingAverage::new`]; a `period` of zero is rejected.
///
/// NOTE(review): with the `serde` feature enabled, deserialized values are not
/// validated. A `deque` whose length differs from `period`, or an `index` outside
/// `0..period`, would make `next` index out of bounds and panic.
#[doc(alias = "SMA")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct SimpleMovingAverage {
    /// Window length; always non-zero when built via `new`.
    period: usize,
    /// Ring-buffer slot that the next input will overwrite.
    index: usize,
    /// Number of inputs received so far, capped at `period`.
    count: usize,
    /// Incrementally maintained sum of the values held in `deque`.
    sum: f64,
    /// Ring buffer of the most recent inputs; slots not yet written hold `0.0`.
    deque: Box<[f64]>,
}

54impl SimpleMovingAverage {
55 pub fn new(period: usize) -> Result<Self> {
56 match period {
57 0 => Err(TaError::InvalidParameter),
58 _ => Ok(Self {
59 period,
60 index: 0,
61 count: 0,
62 sum: 0.0,
63 deque: vec![0.0; period].into_boxed_slice(),
64 }),
65 }
66 }
67}
68
69impl Period for SimpleMovingAverage {
70 fn period(&self) -> usize {
71 self.period
72 }
73}
74
75impl Next<f64> for SimpleMovingAverage {
76 type Output = f64;
77
78 fn next(&mut self, input: f64) -> Self::Output {
79 let old_val = self.deque[self.index];
80 self.deque[self.index] = input;
81
82 self.index = if self.index + 1 < self.period {
83 self.index + 1
84 } else {
85 0
86 };
87
88 if self.count < self.period {
89 self.count += 1;
90 }
91
92 self.sum = self.sum - old_val + input;
93 self.sum / (self.count as f64)
94 }
95}
96
97impl<T: Close> Next<&T> for SimpleMovingAverage {
98 type Output = f64;
99
100 fn next(&mut self, input: &T) -> Self::Output {
101 self.next(input.close())
102 }
103}
104
105impl Reset for SimpleMovingAverage {
106 fn reset(&mut self) {
107 self.index = 0;
108 self.count = 0;
109 self.sum = 0.0;
110 for i in 0..self.period {
111 self.deque[i] = 0.0;
112 }
113 }
114}
115
116impl Default for SimpleMovingAverage {
117 fn default() -> Self {
118 Self::new(9).unwrap()
119 }
120}
121
122impl fmt::Display for SimpleMovingAverage {
123 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
124 write!(f, "SMA({})", self.period)
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helper::*;
    use alloc::format;

    test_indicator!(SimpleMovingAverage);

    /// Feeds every `(input, expected)` pair to `sma` in order, asserting each output.
    fn assert_outputs(sma: &mut SimpleMovingAverage, steps: &[(f64, f64)]) {
        for &(input, expected) in steps {
            assert_eq!(sma.next(input), expected);
        }
    }

    #[test]
    fn test_new() {
        // A zero-length window is rejected; the smallest non-empty window is accepted.
        assert!(SimpleMovingAverage::new(0).is_err());
        assert!(SimpleMovingAverage::new(1).is_ok());
    }

    #[test]
    fn test_next() {
        let mut sma = SimpleMovingAverage::new(4).unwrap();
        assert_outputs(
            &mut sma,
            &[
                (4.0, 4.0),
                (5.0, 4.5),
                (6.0, 5.0),
                (6.0, 5.25),
                (6.0, 5.75),
                (6.0, 6.0),
                (2.0, 5.0),
            ],
        );
    }

    #[test]
    fn test_next_with_bars() {
        let steps: [(f64, f64); 4] = [(4.0, 4.0), (4.0, 4.0), (7.0, 5.0), (1.0, 4.0)];
        let mut sma = SimpleMovingAverage::new(3).unwrap();
        for &(close, expected) in &steps {
            let bar = Bar::new().close(close);
            assert_eq!(sma.next(&bar), expected);
        }
    }

    #[test]
    fn test_reset() {
        let mut sma = SimpleMovingAverage::new(4).unwrap();
        assert_outputs(&mut sma, &[(4.0, 4.0), (5.0, 4.5), (6.0, 5.0)]);

        // After a reset, earlier inputs must no longer influence the average.
        sma.reset();
        assert_outputs(&mut sma, &[(99.0, 99.0)]);
    }

    #[test]
    fn test_default() {
        let _sma = SimpleMovingAverage::default();
    }

    #[test]
    fn test_display() {
        let sma = SimpleMovingAverage::new(5).unwrap();
        let rendered = format!("{}", sma);
        assert_eq!(rendered, "SMA(5)");
    }
}