1use std::collections::HashMap;
20
21use crate::error::IndicatorError;
22use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
23use crate::registry::{param_usize, param_str};
24use crate::types::Candle;
25
26#[derive(Debug, Clone)]
29pub struct WmaParams {
30 pub period: usize,
32 pub column: PriceColumn,
34}
35
36impl Default for WmaParams {
37 fn default() -> Self {
38 Self {
39 period: 14,
40 column: PriceColumn::Close,
41 }
42 }
43}
44
45#[derive(Debug, Clone)]
48pub struct Wma {
49 pub params: WmaParams,
50}
51
52impl Wma {
53 pub fn new(params: WmaParams) -> Self {
54 Self { params }
55 }
56
57 pub fn with_period(period: usize) -> Self {
58 Self::new(WmaParams {
59 period,
60 ..Default::default()
61 })
62 }
63
64 fn output_key(&self) -> String {
65 format!("WMA_{}", self.params.period)
66 }
67}
68
69impl Indicator for Wma {
70 fn name(&self) -> &str {
71 "WMA"
72 }
73 fn required_len(&self) -> usize {
74 self.params.period
75 }
76 fn required_columns(&self) -> &[&'static str] {
77 &["close"]
78 }
79
80 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
87 self.check_len(candles)?;
88
89 let prices = self.params.column.extract(candles);
90 let period = self.params.period;
91 let n = prices.len();
92 let weight_sum = (period * (period + 1) / 2) as f64;
93
94 let mut values = vec![f64::NAN; n];
95
96 for i in (period - 1)..n {
98 let window = &prices[(i + 1 - period)..=i];
99 let weighted: f64 = window
100 .iter()
101 .enumerate()
102 .map(|(j, &p)| (j + 1) as f64 * p)
103 .sum();
104 values[i] = weighted / weight_sum;
105 }
106
107 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
108 }
109}
110
111pub fn factory(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
114 let period = param_usize(params, "period", 14)?;
115 let column = match param_str(params, "column", "close") {
116 "open" => PriceColumn::Open,
117 "high" => PriceColumn::High,
118 "low" => PriceColumn::Low,
119 _ => PriceColumn::Close,
120 };
121 Ok(Box::new(Wma::new(WmaParams { period, column })))
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds flat candles (open = high = low = close) from a list of closes.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i as i64,
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn wma_insufficient_data() {
        assert!(Wma::with_period(5).calculate(&candles(&[1.0, 2.0])).is_err());
    }

    #[test]
    fn wma_period3_known_value() {
        let out = Wma::with_period(3).calculate(&candles(&[1.0, 2.0, 3.0])).unwrap();
        let vals = out.get("WMA_3").unwrap();
        let expected = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0) / 6.0;
        assert!((vals[2] - expected).abs() < 1e-9, "got {}", vals[2]);
    }

    #[test]
    fn wma_slides_window() {
        let out = Wma::with_period(3).calculate(&candles(&[1.0, 2.0, 3.0, 4.0])).unwrap();
        let vals = out.get("WMA_3").unwrap();
        // The last window is [2, 3, 4] with weights 1, 2, 3.
        let expected = (2.0 * 1.0 + 3.0 * 2.0 + 4.0 * 3.0) / 6.0;
        assert!((vals[3] - expected).abs() < 1e-9, "got {}", vals[3]);
    }

    #[test]
    fn wma_period1_is_identity() {
        let closes = [5.0, -2.5, 7.25];
        let out = Wma::with_period(1).calculate(&candles(&closes)).unwrap();
        let vals = out.get("WMA_1").unwrap();
        assert_eq!(vals.len(), closes.len());
        for (got, want) in vals.iter().zip(closes.iter()) {
            assert!((got - want).abs() < 1e-12, "got {}, want {}", got, want);
        }
    }

    #[test]
    fn wma_leading_nans() {
        let out = Wma::with_period(3).calculate(&candles(&[1.0, 2.0, 3.0, 4.0])).unwrap();
        let vals = out.get("WMA_3").unwrap();
        assert!(vals[0].is_nan());
        assert!(vals[1].is_nan());
        assert!(!vals[2].is_nan());
    }

    #[test]
    fn factory_creates_wma() {
        let params = [("period".into(), "10".into())].into();
        assert_eq!(factory(&params).unwrap().name(), "WMA");
    }
}