Skip to main content

indicators/trend/
sma.rs

//! Simple Moving Average (SMA).
//!
//! Python source: `indicators/trend/moving_average.py :: class SMA`
//!
//! # Python algorithm (to port)
//! ```python
//! sma = data[self.column].rolling(window=self.period).mean()
//! return pd.DataFrame({f"{self.name}_{self.period}": sma}, index=data.index)
//! ```
//!
//! Output column: `"SMA_{period}"`, e.g. `"SMA_20"`.
12
13use std::collections::HashMap;
14
15use crate::error::IndicatorError;
16use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
17use crate::registry::{param_str, param_usize};
18use crate::types::Candle;
19
20// ── Params ────────────────────────────────────────────────────────────────────
21
22/// Parameters for the SMA indicator.
23///
24/// Mirrors Python: `self.period = params.get("period", 20)` etc.
25#[derive(Debug, Clone)]
26pub struct SmaParams {
27    /// Rolling window size.  Python default: 20.
28    pub period: usize,
29    /// Which OHLCV field to average.  Python default: `"close"`.
30    pub column: PriceColumn,
31}
32
33impl Default for SmaParams {
34    fn default() -> Self {
35        Self {
36            period: 20,
37            column: PriceColumn::Close,
38        }
39    }
40}
41
42// ── Indicator struct ──────────────────────────────────────────────────────────
43
44/// Simple Moving Average.
45///
46/// Calculates the arithmetic mean of prices over a sliding window.
47///
48/// # Example
49/// ```rust,ignore
50/// let sma = Sma::new(SmaParams { period: 20, ..Default::default() });
51/// let output = sma.calculate(&candles)?;
52/// let values = output.get("SMA_20").unwrap();
53/// ```
54#[derive(Debug, Clone)]
55pub struct Sma {
56    pub params: SmaParams,
57}
58
59impl Sma {
60    pub fn new(params: SmaParams) -> Self {
61        Self { params }
62    }
63
64    /// Convenience constructor with just a period.
65    pub fn with_period(period: usize) -> Self {
66        Self::new(SmaParams {
67            period,
68            ..Default::default()
69        })
70    }
71
72    /// Column label used in `IndicatorOutput`.
73    /// Mirrors Python: `f"{self.name}_{self.period}"`.
74    fn output_key(&self) -> String {
75        format!("SMA_{}", self.params.period)
76    }
77}
78
79impl Indicator for Sma {
80    fn name(&self) -> &'static str {
81        "SMA"
82    }
83
84    fn required_len(&self) -> usize {
85        self.params.period
86    }
87
88    fn required_columns(&self) -> &[&'static str] {
89        &["close"] // adjusts if column != Close, but close is the default
90    }
91
92    /// Ports `data[self.column].rolling(window=self.period).mean()`.
93    ///
94    /// Produces `NaN` for the first `period - 1` positions (matching pandas
95    /// `rolling(window=N)` default of `min_periods=N`), then the arithmetic
96    /// mean of the most recent `period` prices.
97    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
98        self.check_len(candles)?;
99
100        let prices = self.params.column.extract(candles);
101        let period = self.params.period;
102        let n = prices.len();
103
104        let mut values = vec![f64::NAN; n];
105
106        for i in (period - 1)..n {
107            let sum: f64 = prices[(i + 1 - period)..=i].iter().sum();
108            values[i] = sum / period as f64;
109        }
110
111        Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
112    }
113}
114
115// ── Registry factory ──────────────────────────────────────────────────────────
116
117/// Factory function registered under `"sma"` in the global registry.
118pub fn factory<S: ::std::hash::BuildHasher>(
119    params: &HashMap<String, String, S>,
120) -> Result<Box<dyn Indicator>, IndicatorError> {
121    let period = param_usize(params, "period", 20)?;
122    let column = match param_str(params, "column", "close") {
123        "open" => PriceColumn::Open,
124        "high" => PriceColumn::High,
125        "low" => PriceColumn::Low,
126        "volume" => PriceColumn::Volume,
127        _ => PriceColumn::Close,
128    };
129    Ok(Box::new(Sma::new(SmaParams { period, column })))
130}
131
132// ── Tests ─────────────────────────────────────────────────────────────────────
133
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's `Candle` import.
    use super::*;

    /// Builds candles whose OHLC fields all equal the given price.
    fn make_candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    /// Asserts that `actual` is within 1e-9 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn sma_insufficient_data() {
        let sma = Sma::with_period(5);
        let err = sma.calculate(&make_candles(&[1.0, 2.0])).unwrap_err();
        assert!(matches!(err, IndicatorError::InsufficientData { .. }));
    }

    #[test]
    fn sma_output_key() {
        let sma = Sma::with_period(20);
        assert_eq!(sma.output_key(), "SMA_20");
    }

    #[test]
    fn sma_first_value_is_nan() {
        // Warm-up: the first `period - 1` outputs are NaN; the first full
        // window [10, 11, 12, 13, 14] averages to 12.
        let closes = [10.0, 11.0, 12.0, 13.0, 14.0];
        let sma = Sma::with_period(5);
        let out = sma.calculate(&make_candles(&closes)).unwrap();
        let vals = out.get("SMA_5").unwrap();
        assert_eq!(vals.len(), closes.len());
        assert!(vals[..4].iter().all(|v| v.is_nan()));
        assert_close(vals[4], 12.0);
    }

    #[test]
    fn sma_last_value_correct() {
        // SMA(3) of [10, 20, 30] = 20
        let closes = [10.0, 20.0, 30.0];
        let sma = Sma::with_period(3);
        let out = sma.calculate(&make_candles(&closes)).unwrap();
        let vals = out.get("SMA_3").unwrap();
        assert_close(vals[2], 20.0);
    }

    #[test]
    fn sma_rolling_window() {
        // [1,2,3,4,5], period=3 → NaN, NaN, 2.0, 3.0, 4.0
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0];
        let sma = Sma::with_period(3);
        let out = sma.calculate(&make_candles(&closes)).unwrap();
        let vals = out.get("SMA_3").unwrap();
        assert_eq!(vals.len(), closes.len());
        assert!(vals[0].is_nan());
        assert!(vals[1].is_nan());
        assert_close(vals[2], 2.0);
        assert_close(vals[3], 3.0);
        assert_close(vals[4], 4.0);
    }

    #[test]
    fn factory_creates_sma() {
        let params = [("period".into(), "10".into())].into();
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "SMA");
        assert_eq!(ind.required_len(), 10);
    }
}