1use std::collections::HashMap;
20
21use crate::error::IndicatorError;
22use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
23use crate::registry::{param_str, param_usize};
24use crate::types::Candle;
25
26#[derive(Debug, Clone)]
29pub struct WmaParams {
30 pub period: usize,
32 pub column: PriceColumn,
34}
35
36impl Default for WmaParams {
37 fn default() -> Self {
38 Self {
39 period: 14,
40 column: PriceColumn::Close,
41 }
42 }
43}
44
45#[derive(Debug, Clone)]
48pub struct Wma {
49 pub params: WmaParams,
50}
51
52impl Wma {
53 pub fn new(params: WmaParams) -> Self {
54 Self { params }
55 }
56
57 pub fn with_period(period: usize) -> Self {
58 Self::new(WmaParams {
59 period,
60 ..Default::default()
61 })
62 }
63
64 fn output_key(&self) -> String {
65 format!("WMA_{}", self.params.period)
66 }
67}
68
69impl Indicator for Wma {
70 fn name(&self) -> &'static str {
71 "WMA"
72 }
73 fn required_len(&self) -> usize {
74 self.params.period
75 }
76 fn required_columns(&self) -> &[&'static str] {
77 &["close"]
78 }
79
80 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
86 self.check_len(candles)?;
87
88 let prices = self.params.column.extract(candles);
89 let period = self.params.period;
90 let n = prices.len();
91 let weight_sum = (period * (period + 1) / 2) as f64;
92
93 let mut values = vec![f64::NAN; n];
94
95 for i in (period - 1)..n {
96 let window = &prices[(i + 1 - period)..=i];
97 let weighted: f64 = window
98 .iter()
99 .enumerate()
100 .map(|(j, &p)| (j + 1) as f64 * p)
101 .sum();
102 values[i] = weighted / weight_sum;
103 }
104
105 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
106 }
107}
108
109pub fn factory<S: ::std::hash::BuildHasher>(
112 params: &HashMap<String, String, S>,
113) -> Result<Box<dyn Indicator>, IndicatorError> {
114 let period = param_usize(params, "period", 14)?;
115 let column = match param_str(params, "column", "close") {
116 "open" => PriceColumn::Open,
117 "high" => PriceColumn::High,
118 "low" => PriceColumn::Low,
119 _ => PriceColumn::Close,
120 };
121 Ok(Box::new(Wma::new(WmaParams { period, column })))
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds flat candles (open = high = low = close) from a close series,
    /// timestamped by their index, with unit volume.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn wma_insufficient_data() {
        assert!(
            Wma::with_period(5)
                .calculate(&candles(&[1.0, 2.0]))
                .is_err()
        );
    }

    #[test]
    fn wma_period3_known_value() {
        let out = Wma::with_period(3)
            .calculate(&candles(&[1.0, 2.0, 3.0]))
            .unwrap();
        let vals = out.get("WMA_3").unwrap();
        let expected = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0) / 6.0;
        assert!((vals[2] - expected).abs() < 1e-9, "got {}", vals[2]);
    }

    #[test]
    fn wma_leading_nans() {
        let out = Wma::with_period(3)
            .calculate(&candles(&[1.0, 2.0, 3.0, 4.0]))
            .unwrap();
        let vals = out.get("WMA_3").unwrap();
        assert!(vals[0].is_nan());
        assert!(vals[1].is_nan());
        assert!(!vals[2].is_nan());
    }

    #[test]
    fn wma_window_slides_forward() {
        let out = Wma::with_period(3)
            .calculate(&candles(&[1.0, 2.0, 3.0, 4.0]))
            .unwrap();
        let vals = out.get("WMA_3").unwrap();
        // Window [2, 3, 4] with weights 1, 2, 3.
        let expected = (2.0 * 1.0 + 3.0 * 2.0 + 4.0 * 3.0) / 6.0;
        assert!((vals[3] - expected).abs() < 1e-9, "got {}", vals[3]);
    }

    #[test]
    fn factory_creates_wma() {
        let params = [("period".into(), "10".into())].into();
        assert_eq!(factory(&params).unwrap().name(), "WMA");
    }

    #[test]
    fn factory_defaults_to_period_14() {
        let params: HashMap<String, String> = HashMap::new();
        let wma = factory(&params).unwrap();
        let out = wma.calculate(&candles(&[1.0; 14])).unwrap();
        let vals = out.get("WMA_14").unwrap();
        assert!((vals[13] - 1.0).abs() < 1e-9, "got {}", vals[13]);
    }
}