Skip to main content

indicators/trend/
wma.rs

1//! Weighted Moving Average (WMA).
2//!
3//! Python source: `indicators/trend/moving_average.py :: class WMA`
4//!              + `indicators/trend/weighted_moving_average.py :: class WMA`
5//!
6//! # Python algorithm (to port)
7//! ```python
8//! weights = np.arange(1, self.period + 1)          # [1, 2, ..., period]
9//! wma = data[self.column].rolling(window=self.period).apply(
10//!     lambda x: np.sum(weights * x) / weights.sum(), raw=True
11//! )
12//! ```
13//!
14//! Weight for index `i` (0-based within window) = `i + 1`.
15//! Denominator = `period * (period + 1) / 2`.
16//!
17//! Output column: `"WMA_{period}"`.
18
19use std::collections::HashMap;
20
21use crate::error::IndicatorError;
22use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
23use crate::registry::{param_str, param_usize};
24use crate::types::Candle;
25
26// ── Params ────────────────────────────────────────────────────────────────────
27
28#[derive(Debug, Clone)]
29pub struct WmaParams {
30    /// Lookback period.  Python default: 14 (weighted_moving_average.py) / 20 (moving_average.py).
31    pub period: usize,
32    /// Price field.  Python default: `"close"`.
33    pub column: PriceColumn,
34}
35
36impl Default for WmaParams {
37    fn default() -> Self {
38        Self {
39            period: 14,
40            column: PriceColumn::Close,
41        }
42    }
43}
44
45// ── Indicator struct ──────────────────────────────────────────────────────────
46
47#[derive(Debug, Clone)]
48pub struct Wma {
49    pub params: WmaParams,
50}
51
52impl Wma {
53    pub fn new(params: WmaParams) -> Self {
54        Self { params }
55    }
56
57    pub fn with_period(period: usize) -> Self {
58        Self::new(WmaParams {
59            period,
60            ..Default::default()
61        })
62    }
63
64    fn output_key(&self) -> String {
65        format!("WMA_{}", self.params.period)
66    }
67}
68
69impl Indicator for Wma {
70    fn name(&self) -> &'static str {
71        "WMA"
72    }
73    fn required_len(&self) -> usize {
74        self.params.period
75    }
76    fn required_columns(&self) -> &[&'static str] {
77        &["close"]
78    }
79
80    /// Ports `rolling(window=period).apply(lambda x: np.sum(weights * x) / weights.sum())`.
81    ///
82    /// Weights are linear: position `j` (0-based within the window) receives
83    /// weight `j + 1`.  Denominator = `period * (period + 1) / 2`.
84    /// Produces `NaN` for the first `period - 1` positions.
85    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
86        self.check_len(candles)?;
87
88        let prices = self.params.column.extract(candles);
89        let period = self.params.period;
90        let n = prices.len();
91        let weight_sum = (period * (period + 1) / 2) as f64;
92
93        let mut values = vec![f64::NAN; n];
94
95        for i in (period - 1)..n {
96            let window = &prices[(i + 1 - period)..=i];
97            let weighted: f64 = window
98                .iter()
99                .enumerate()
100                .map(|(j, &p)| (j + 1) as f64 * p)
101                .sum();
102            values[i] = weighted / weight_sum;
103        }
104
105        Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
106    }
107}
108
109// ── Registry factory ──────────────────────────────────────────────────────────
110
111pub fn factory<S: ::std::hash::BuildHasher>(
112    params: &HashMap<String, String, S>,
113) -> Result<Box<dyn Indicator>, IndicatorError> {
114    let period = param_usize(params, "period", 14)?;
115    let column = match param_str(params, "column", "close") {
116        "open" => PriceColumn::Open,
117        "high" => PriceColumn::High,
118        "low" => PriceColumn::Low,
119        _ => PriceColumn::Close,
120    };
121    Ok(Box::new(Wma::new(WmaParams { period, column })))
122}
123
124// ── Tests ─────────────────────────────────────────────────────────────────────
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds flat candles (open == high == low == close) from closing prices.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn wma_insufficient_data() {
        assert!(
            Wma::with_period(5)
                .calculate(&candles(&[1.0, 2.0]))
                .is_err()
        );
    }

    #[test]
    fn wma_period3_known_value() {
        // weights [1,2,3], sum=6; prices [1,2,3] → (1+4+9)/6 = 14/6 ≈ 2.333
        let out = Wma::with_period(3)
            .calculate(&candles(&[1.0, 2.0, 3.0]))
            .unwrap();
        let vals = out.get("WMA_3").unwrap();
        let expected = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0) / 6.0;
        assert!((vals[2] - expected).abs() < 1e-9, "got {}", vals[2]);
    }

    #[test]
    fn wma_leading_nans() {
        let out = Wma::with_period(3)
            .calculate(&candles(&[1.0, 2.0, 3.0, 4.0]))
            .unwrap();
        let vals = out.get("WMA_3").unwrap();
        assert!(vals[0].is_nan());
        assert!(vals[1].is_nan());
        assert!(!vals[2].is_nan());
    }

    #[test]
    fn wma_rolling_windows() {
        // Later windows slide by one bar; weights stay [1,2,3] over sum 6.
        let out = Wma::with_period(3)
            .calculate(&candles(&[1.0, 2.0, 3.0, 4.0, 5.0]))
            .unwrap();
        let vals = out.get("WMA_3").unwrap();
        let expected_3 = (1.0 * 2.0 + 2.0 * 3.0 + 3.0 * 4.0) / 6.0;
        let expected_4 = (1.0 * 3.0 + 2.0 * 4.0 + 3.0 * 5.0) / 6.0;
        assert!((vals[3] - expected_3).abs() < 1e-9, "got {}", vals[3]);
        assert!((vals[4] - expected_4).abs() < 1e-9, "got {}", vals[4]);
    }

    #[test]
    fn wma_reads_configured_column() {
        // Highs sit 10.0 above closes, so the WMA must be computed from them.
        let mut data = candles(&[1.0, 2.0, 3.0]);
        for candle in &mut data {
            candle.high = candle.close + 10.0;
        }
        let wma = Wma::new(WmaParams {
            period: 2,
            column: PriceColumn::High,
        });
        let out = wma.calculate(&data).unwrap();
        let vals = out.get("WMA_2").unwrap();
        // weights [1,2], sum=3; highs [12,13]
        let expected = (1.0 * 12.0 + 2.0 * 13.0) / 3.0;
        assert!((vals[2] - expected).abs() < 1e-9, "got {}", vals[2]);
    }

    #[test]
    fn factory_creates_wma() {
        let params = [("period".into(), "10".into())].into();
        assert_eq!(factory(&params).unwrap().name(), "WMA");
    }

    #[test]
    fn factory_honours_column_param() {
        let params: HashMap<String, String> = [
            ("period".to_string(), "2".to_string()),
            ("column".to_string(), "low".to_string()),
        ]
        .into();
        // Lows sit 5.0 below closes.
        let mut data = candles(&[1.0, 2.0, 3.0]);
        for candle in &mut data {
            candle.low = candle.close - 5.0;
        }
        let out = factory(&params).unwrap().calculate(&data).unwrap();
        let vals = out.get("WMA_2").unwrap();
        // weights [1,2], sum=3; lows [-3,-2]
        let expected = (1.0 * -3.0 + 2.0 * -2.0) / 3.0;
        assert!((vals[2] - expected).abs() < 1e-9, "got {}", vals[2]);
    }
}