1use std::collections::HashMap;
14
15use crate::error::IndicatorError;
16use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
17use crate::registry::{param_usize, param_str};
18use crate::types::Candle;
19
20#[derive(Debug, Clone)]
26pub struct SmaParams {
27 pub period: usize,
29 pub column: PriceColumn,
31}
32
33impl Default for SmaParams {
34 fn default() -> Self {
35 Self {
36 period: 20,
37 column: PriceColumn::Close,
38 }
39 }
40}
41
42#[derive(Debug, Clone)]
55pub struct Sma {
56 pub params: SmaParams,
57}
58
59impl Sma {
60 pub fn new(params: SmaParams) -> Self {
61 Self { params }
62 }
63
64 pub fn with_period(period: usize) -> Self {
66 Self::new(SmaParams { period, ..Default::default() })
67 }
68
69 fn output_key(&self) -> String {
72 format!("SMA_{}", self.params.period)
73 }
74}
75
76impl Indicator for Sma {
77 fn name(&self) -> &str {
78 "SMA"
79 }
80
81 fn required_len(&self) -> usize {
82 self.params.period
83 }
84
85 fn required_columns(&self) -> &[&'static str] {
86 &["close"] }
88
89 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
96 self.check_len(candles)?;
97
98 let prices = self.params.column.extract(candles);
99 let period = self.params.period;
100 let n = prices.len();
101
102 let mut values = vec![f64::NAN; n];
103
104 for i in (period - 1)..n {
106 let sum: f64 = prices[(i + 1 - period)..=i].iter().sum();
107 values[i] = sum / period as f64;
108 }
109
110 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
111 }
112}
113
114pub fn factory(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
118 let period = param_usize(params, "period", 20)?;
119 let column = match param_str(params, "column", "close") {
120 "open" => PriceColumn::Open,
121 "high" => PriceColumn::High,
122 "low" => PriceColumn::Low,
123 "volume" => PriceColumn::Volume,
124 _ => PriceColumn::Close,
125 };
126 Ok(Box::new(Sma::new(SmaParams { period, column })))
127}
128
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's imports (`Candle`, `HashMap`, ...).
    use super::*;

    /// Builds flat candles (open = high = low = close) from a list of closes,
    /// with sequential timestamps and unit volume.
    fn make_candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i as i64,
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    /// Asserts two floats agree to within 1e-9.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn sma_insufficient_data() {
        let sma = Sma::with_period(5);
        let err = sma.calculate(&make_candles(&[1.0, 2.0])).unwrap_err();
        assert!(matches!(err, IndicatorError::InsufficientData { .. }));
    }

    #[test]
    fn sma_output_key() {
        let sma = Sma::with_period(20);
        assert_eq!(sma.output_key(), "SMA_20");
    }

    #[test]
    fn sma_first_value_is_nan() {
        let closes = [10.0, 11.0, 12.0, 13.0, 14.0];
        let sma = Sma::with_period(5);
        let out = sma.calculate(&make_candles(&closes)).unwrap();
        let vals = out.get("SMA_5").unwrap();
        // Every slot before the first complete window is NaN ...
        assert!(vals[..4].iter().all(|v| v.is_nan()));
        // ... and the first complete window is the plain mean.
        assert_close(vals[4], 12.0);
    }

    #[test]
    fn sma_last_value_correct() {
        let closes = [10.0, 20.0, 30.0];
        let sma = Sma::with_period(3);
        let out = sma.calculate(&make_candles(&closes)).unwrap();
        let vals = out.get("SMA_3").unwrap();
        assert_close(vals[2], 20.0);
    }

    #[test]
    fn sma_rolling_window() {
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0];
        let sma = Sma::with_period(3);
        let out = sma.calculate(&make_candles(&closes)).unwrap();
        let vals = out.get("SMA_3").unwrap();
        assert_close(vals[2], 2.0);
        assert_close(vals[3], 3.0);
        assert_close(vals[4], 4.0);
    }

    #[test]
    fn sma_period_one_is_identity() {
        let closes = [3.0, 7.0, -2.5];
        let sma = Sma::with_period(1);
        let out = sma.calculate(&make_candles(&closes)).unwrap();
        let vals = out.get("SMA_1").unwrap();
        for (&got, &want) in vals.iter().zip(closes.iter()) {
            assert_close(got, want);
        }
    }

    #[test]
    fn factory_creates_sma() {
        let params: HashMap<String, String> =
            [("period".to_string(), "10".to_string())].into();
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "SMA");
        assert_eq!(ind.required_len(), 10);
    }
}