// finlib_ta/indicators/simple_moving_average.rs

1use core::fmt;
2
3use crate::errors::{Result, TaError};
4use crate::{Close, Next, Period, Reset};
5use alloc::boxed::Box;
6use alloc::vec;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9
/// Simple moving average (SMA).
///
/// The arithmetic mean of the most recent `period` inputs.
///
/// # Formula
///
/// ![SMA](https://wikimedia.org/api/rest_v1/media/math/render/svg/e2bf09dc6deaf86b3607040585fac6078f9c7c89)
///
/// SMA<sub>t</sub> = (p<sub>t</sub> + p<sub>t-1</sub> + ... + p<sub>t-period+1</sub>) / period
///
/// Where:
///
/// * _SMA<sub>t</sub>_ - value of the simple moving average at point in time _t_
/// * _period_ - length of the averaging window, counted in inputs
/// * _p<sub>t</sub>_ - input value at point in time _t_
///
/// Until `period` inputs have been received, the average is taken over the
/// inputs seen so far (see the first two outputs in the example below).
///
/// # Parameters
///
/// * _period_ - window length (integer greater than 0)
///
/// # Example
///
/// ```
/// use finlib_ta::indicators::SimpleMovingAverage;
/// use finlib_ta::Next;
///
/// let mut sma = SimpleMovingAverage::new(3).unwrap();
/// assert_eq!(sma.next(10.0), 10.0);
/// assert_eq!(sma.next(11.0), 10.5);
/// assert_eq!(sma.next(12.0), 11.0);
/// assert_eq!(sma.next(13.0), 12.0);
/// ```
///
/// # Links
///
/// * [Simple Moving Average, Wikipedia](https://en.wikipedia.org/wiki/Moving_average#Simple_moving_average)
///
#[doc(alias = "SMA")]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct SimpleMovingAverage {
    /// Window length; greater than 0 for every value built by `new`.
    period: usize,
    /// Slot of `deque` that the next input is written to.
    index: usize,
    /// Number of inputs currently in the window (saturates at `period`).
    count: usize,
    /// Sum of the values currently in the window, maintained by `next`.
    sum: f64,
    /// Ring buffer holding the last `period` inputs; slots not yet written hold 0.0.
    deque: Box<[f64]>,
}
53
54impl SimpleMovingAverage {
55    pub fn new(period: usize) -> Result<Self> {
56        match period {
57            0 => Err(TaError::InvalidParameter),
58            _ => Ok(Self {
59                period,
60                index: 0,
61                count: 0,
62                sum: 0.0,
63                deque: vec![0.0; period].into_boxed_slice(),
64            }),
65        }
66    }
67}
68
69impl Period for SimpleMovingAverage {
70    fn period(&self) -> usize {
71        self.period
72    }
73}
74
75impl Next<f64> for SimpleMovingAverage {
76    type Output = f64;
77
78    fn next(&mut self, input: f64) -> Self::Output {
79        let old_val = self.deque[self.index];
80        self.deque[self.index] = input;
81
82        self.index = if self.index + 1 < self.period {
83            self.index + 1
84        } else {
85            0
86        };
87
88        if self.count < self.period {
89            self.count += 1;
90        }
91
92        self.sum = self.sum - old_val + input;
93        self.sum / (self.count as f64)
94    }
95}
96
97impl<T: Close> Next<&T> for SimpleMovingAverage {
98    type Output = f64;
99
100    fn next(&mut self, input: &T) -> Self::Output {
101        self.next(input.close())
102    }
103}
104
105impl Reset for SimpleMovingAverage {
106    fn reset(&mut self) {
107        self.index = 0;
108        self.count = 0;
109        self.sum = 0.0;
110        for i in 0..self.period {
111            self.deque[i] = 0.0;
112        }
113    }
114}
115
116impl Default for SimpleMovingAverage {
117    fn default() -> Self {
118        Self::new(9).unwrap()
119    }
120}
121
122impl fmt::Display for SimpleMovingAverage {
123    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
124        write!(f, "SMA({})", self.period)
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helper::*;
    use alloc::format;

    test_indicator!(SimpleMovingAverage);

    /// Feeds each `(input, expected)` pair to `sma` in order and checks the output.
    fn assert_outputs(sma: &mut SimpleMovingAverage, cases: &[(f64, f64)]) {
        for &(input, expected) in cases {
            assert_eq!(sma.next(input), expected);
        }
    }

    #[test]
    fn test_new() {
        // A zero-length window is rejected; any positive length is accepted.
        assert!(SimpleMovingAverage::new(0).is_err());
        assert!(SimpleMovingAverage::new(1).is_ok());
    }

    #[test]
    fn test_next() {
        let mut sma = SimpleMovingAverage::new(4).unwrap();
        // The first four outputs average over the inputs seen so far;
        // after that, the oldest input drops out on each step.
        assert_outputs(
            &mut sma,
            &[
                (4.0, 4.0),
                (5.0, 4.5),
                (6.0, 5.0),
                (6.0, 5.25),
                (6.0, 5.75),
                (6.0, 6.0),
                (2.0, 5.0),
            ],
        );
    }

    #[test]
    fn test_next_with_bars() {
        let mut sma = SimpleMovingAverage::new(3).unwrap();
        for &(close, expected) in &[(4.0, 4.0), (4.0, 4.0), (7.0, 5.0), (1.0, 4.0)] {
            let bar = Bar::new().close(close);
            assert_eq!(sma.next(&bar), expected);
        }
    }

    #[test]
    fn test_reset() {
        let mut sma = SimpleMovingAverage::new(4).unwrap();
        assert_outputs(&mut sma, &[(4.0, 4.0), (5.0, 4.5), (6.0, 5.0)]);

        // After a reset the earlier inputs no longer contribute.
        sma.reset();
        assert_eq!(sma.next(99.0), 99.0);
    }

    #[test]
    fn test_default() {
        SimpleMovingAverage::default();
    }

    #[test]
    fn test_display() {
        let sma = SimpleMovingAverage::new(5).unwrap();
        assert_eq!(format!("{}", sma), "SMA(5)");
    }
}
188}