// quantaxis_rs/indicators/simple_moving_average.rs

1use std::fmt;
2
3use crate::errors::*;
4use crate::{Close, Next, Reset};
5
/// Simple moving average (SMA).
///
/// # Formula
///
/// ![SMA](https://wikimedia.org/api/rest_v1/media/math/render/svg/e2bf09dc6deaf86b3607040585fac6078f9c7c89)
///
/// Where:
///
/// * _SMA<sub>t</sub>_ - value of simple moving average at a point of time _t_
/// * _n_ - number of periods (length)
/// * _p<sub>t</sub>_ - input value at a point of time _t_
///
/// Until _n_ values have been supplied, the average is taken over the values
/// received so far (see the first three outputs in the example below).
///
/// # Parameters
///
/// * _n_ - number of periods (integer greater than 0)
///
/// # Example
///
/// ```
/// use quantaxis_rs::indicators::SimpleMovingAverage;
/// use quantaxis_rs::Next;
///
/// let mut sma = SimpleMovingAverage::new(3).unwrap();
/// assert_eq!(sma.next(10.0), 10.0);
/// assert_eq!(sma.next(11.0), 10.5);
/// assert_eq!(sma.next(12.0), 11.0);
/// assert_eq!(sma.next(13.0), 12.0);
/// ```
#[derive(Debug, Clone)]
pub struct SimpleMovingAverage {
    /// Number of periods in the averaging window.
    n: u32,
    /// Ring-buffer cursor: the slot of `vec` written by the most recent
    /// `next` call (`0` before any input has been fed).
    index: usize,
    /// Number of inputs contributing to the average; grows by one per input
    /// and stops at `n`.
    count: u32,
    /// Running sum of the values currently held in `vec`.
    sum: f64,
    /// Fixed-size buffer of `n` slots, zero-filled until overwritten by inputs.
    vec: Vec<f64>,
}
44
45impl SimpleMovingAverage {
46    pub fn new(n: u32) -> Result<Self> {
47        match n {
48            0 => Err(Error::from_kind(ErrorKind::InvalidParameter)),
49            _ => {
50                let indicator = Self {
51                    n: n,
52                    index: 0,
53                    count: 0,
54                    sum: 0.0,
55                    vec: vec![0.0; n as usize],
56                };
57                Ok(indicator)
58            }
59        }
60    }
61}
62
63impl Next<f64> for SimpleMovingAverage {
64    type Output = f64;
65
66    fn next(&mut self, input: f64) -> Self::Output {
67        self.index = (self.index + 1) % (self.n as usize);
68
69        let old_val = self.vec[self.index];
70        self.vec[self.index] = input;
71
72        if self.count < self.n {
73            self.count += 1;
74        }
75
76        self.sum = self.sum - old_val + input;
77        self.sum / (self.count as f64)
78    }
79}
80
81impl<'a, T: Close> Next<&'a T> for SimpleMovingAverage {
82    type Output = f64;
83
84    fn next(&mut self, input: &'a T) -> Self::Output {
85        self.next(input.close())
86    }
87}
88
89impl Reset for SimpleMovingAverage {
90    fn reset(&mut self) {
91        self.index = 0;
92        self.count = 0;
93        self.sum = 0.0;
94        for i in 0..(self.n as usize) {
95            self.vec[i] = 0.0;
96        }
97    }
98}
99
100impl Default for SimpleMovingAverage {
101    fn default() -> Self {
102        Self::new(9).unwrap()
103    }
104}
105
106impl fmt::Display for SimpleMovingAverage {
107    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
108        write!(f, "SMA({})", self.n)
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helper::*;

    /// Smoke test: the indicator implements `Default`, `Next<f64>`,
    /// `Next<&Bar>`, `Reset` and `Display`, and `reset` restores the
    /// behaviour of a fresh instance.
    #[test]
    fn test_indicator() {
        let bar = Bar::new();

        // `Default` is implemented.
        let mut indicator = SimpleMovingAverage::default();

        // `Next<f64>` is implemented.
        let first_output = indicator.next(12.3);

        // `next` accepts a reference to a data item as well.
        indicator.next(&bar);

        // After `reset` the same input yields the same first output.
        indicator.reset();
        assert_eq!(indicator.next(12.3), first_output);

        // `Display` is implemented.
        let _ = indicator.to_string();
    }

    /// A zero period is rejected; a period of one is accepted.
    #[test]
    fn test_new() {
        let zero_period = SimpleMovingAverage::new(0);
        assert!(zero_period.is_err());

        let one_period = SimpleMovingAverage::new(1);
        assert!(one_period.is_ok());
    }

    /// Feeds raw values through a 4-period window and checks every output.
    #[test]
    fn test_next() {
        let mut sma = SimpleMovingAverage::new(4).unwrap();

        // (input, expected average after that input)
        let steps: [(f64, f64); 7] = [
            (4.0, 4.0),
            (5.0, 4.5),
            (6.0, 5.0),
            (6.0, 5.25),
            (6.0, 5.75),
            (6.0, 6.0),
            (2.0, 5.0),
        ];
        for &(input, expected) in steps.iter() {
            assert_eq!(sma.next(input), expected);
        }
    }

    /// Feeds bars through a 3-period window; only the close value matters.
    #[test]
    fn test_next_with_bars() {
        let mut sma = SimpleMovingAverage::new(3).unwrap();

        // (close of the bar, expected average after that bar)
        let steps: [(f64, f64); 4] = [(4.0, 4.0), (4.0, 4.0), (7.0, 5.0), (1.0, 4.0)];
        for &(close, expected) in steps.iter() {
            let bar = Bar::new().close(close);
            assert_eq!(sma.next(&bar), expected);
        }
    }

    /// After `reset`, earlier inputs no longer influence the average.
    #[test]
    fn test_reset() {
        let mut sma = SimpleMovingAverage::new(4).unwrap();

        let warm_up: [(f64, f64); 3] = [(4.0, 4.0), (5.0, 4.5), (6.0, 5.0)];
        for &(input, expected) in warm_up.iter() {
            assert_eq!(sma.next(input), expected);
        }

        sma.reset();
        assert_eq!(sma.next(99.0), 99.0);
    }

    /// Constructing the default indicator must not panic.
    #[test]
    fn test_default() {
        let _default_sma = SimpleMovingAverage::default();
    }

    /// The textual form is `SMA(<period>)`.
    #[test]
    fn test_display() {
        let sma = SimpleMovingAverage::new(5).unwrap();
        assert_eq!(sma.to_string(), "SMA(5)");
    }
}