1use std::collections::HashMap;
14
15use crate::error::IndicatorError;
16use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
17use crate::registry::{param_str, param_usize};
18use crate::types::Candle;
19
20#[derive(Debug, Clone)]
26pub struct SmaParams {
27 pub period: usize,
29 pub column: PriceColumn,
31}
32
33impl Default for SmaParams {
34 fn default() -> Self {
35 Self {
36 period: 20,
37 column: PriceColumn::Close,
38 }
39 }
40}
41
42#[derive(Debug, Clone)]
55pub struct Sma {
56 pub params: SmaParams,
57}
58
59impl Sma {
60 pub fn new(params: SmaParams) -> Self {
61 Self { params }
62 }
63
64 pub fn with_period(period: usize) -> Self {
66 Self::new(SmaParams {
67 period,
68 ..Default::default()
69 })
70 }
71
72 fn output_key(&self) -> String {
75 format!("SMA_{}", self.params.period)
76 }
77}
78
79impl Indicator for Sma {
80 fn name(&self) -> &'static str {
81 "SMA"
82 }
83
84 fn required_len(&self) -> usize {
85 self.params.period
86 }
87
88 fn required_columns(&self) -> &[&'static str] {
89 &["close"] }
91
92 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
98 self.check_len(candles)?;
99
100 let prices = self.params.column.extract(candles);
101 let period = self.params.period;
102 let n = prices.len();
103
104 let mut values = vec![f64::NAN; n];
105
106 for i in (period - 1)..n {
107 let sum: f64 = prices[(i + 1 - period)..=i].iter().sum();
108 values[i] = sum / period as f64;
109 }
110
111 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
112 }
113}
114
115pub fn factory<S: ::std::hash::BuildHasher>(
119 params: &HashMap<String, String, S>,
120) -> Result<Box<dyn Indicator>, IndicatorError> {
121 let period = param_usize(params, "period", 20)?;
122 let column = match param_str(params, "column", "close") {
123 "open" => PriceColumn::Open,
124 "high" => PriceColumn::High,
125 "low" => PriceColumn::Low,
126 "volume" => PriceColumn::Volume,
127 _ => PriceColumn::Close,
128 };
129 Ok(Box::new(Sma::new(SmaParams { period, column })))
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Candle;

    /// Builds one candle per close price. Open/high/low equal the close,
    /// volume is a constant 1.0, and `time` is the index.
    fn make_candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn sma_insufficient_data() {
        let sma = Sma::with_period(5);
        let err = sma.calculate(&make_candles(&[1.0, 2.0])).unwrap_err();
        assert!(matches!(err, IndicatorError::InsufficientData { .. }));
    }

    #[test]
    fn sma_output_key() {
        let sma = Sma::with_period(20);
        assert_eq!(sma.output_key(), "SMA_20");
    }

    #[test]
    fn sma_first_value_is_nan() {
        let closes = vec![10.0, 11.0, 12.0, 13.0, 14.0];
        let sma = Sma::with_period(5);
        let out = sma.calculate(&make_candles(&closes)).unwrap();
        let vals = out.get("SMA_5").unwrap();
        assert!(vals[0].is_nan());
        assert!(vals[3].is_nan());
        // The first full window is (10 + 11 + 12 + 13 + 14) / 5.
        assert!((vals[4] - 12.0).abs() < 1e-9, "expected 12.0, got {}", vals[4]);
    }

    #[test]
    fn sma_last_value_correct() {
        let closes = vec![10.0, 20.0, 30.0];
        let sma = Sma::with_period(3);
        let out = sma.calculate(&make_candles(&closes)).unwrap();
        let vals = out.get("SMA_3").unwrap();
        assert!(
            (vals[2] - 20.0).abs() < 1e-9,
            "expected 20.0, got {}",
            vals[2]
        );
    }

    #[test]
    fn sma_rolling_window() {
        let closes = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let sma = Sma::with_period(3);
        let out = sma.calculate(&make_candles(&closes)).unwrap();
        let vals = out.get("SMA_3").unwrap();
        assert!((vals[2] - 2.0).abs() < 1e-9);
        assert!((vals[3] - 3.0).abs() < 1e-9);
        assert!((vals[4] - 4.0).abs() < 1e-9);
    }

    #[test]
    fn sma_period_one_is_identity() {
        let closes = vec![3.0, 1.0, 4.0, 1.0, 5.0];
        let sma = Sma::with_period(1);
        let out = sma.calculate(&make_candles(&closes)).unwrap();
        let vals = out.get("SMA_1").unwrap();
        for (got, want) in vals.iter().zip(&closes) {
            assert!((got - want).abs() < 1e-9, "expected {want}, got {got}");
        }
    }

    #[test]
    fn factory_creates_sma() {
        let params = [("period".into(), "10".into())].into();
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "SMA");
        assert_eq!(ind.required_len(), 10);
    }

    #[test]
    fn factory_honours_column_param() {
        let params = [
            ("period".into(), "2".into()),
            ("column".into(), "volume".into()),
        ]
        .into();
        let ind = factory(&params).unwrap();
        let out = ind.calculate(&make_candles(&[10.0, 20.0, 30.0])).unwrap();
        let vals = out.get("SMA_2").unwrap();
        assert!(vals[0].is_nan());
        // make_candles gives every candle volume 1.0, so a volume SMA is 1.0.
        // A close-price SMA would instead be 15.0 / 25.0.
        assert!((vals[1] - 1.0).abs() < 1e-9, "expected 1.0, got {}", vals[1]);
        assert!((vals[2] - 1.0).abs() < 1e-9, "expected 1.0, got {}", vals[2]);
    }
}