Skip to main content

indicators/momentum/
williams_r.rs

1//! Williams %R.
2//!
3//! Python source: `indicators/other/williams_r.py :: class WilliamsRIndicator`
4//!
5//! # Python algorithm (to port)
6//! ```python
7//! highest_high = data["High"].rolling(window=self.period).max()
8//! lowest_low   = data["Low"].rolling(window=self.period).min()
9//! will_r       = -100 * (highest_high - data["Close"]) / (highest_high - lowest_low)
10//! ```
11//!
12//! Oscillates between -100 and 0.  Above -20 → overbought; below -80 → oversold.
13//!
14//! Output column: `"WR_{period}"`.
15
16use std::collections::HashMap;
17
18use crate::error::IndicatorError;
19use crate::indicator::{Indicator, IndicatorOutput};
20use crate::registry::param_usize;
21use crate::types::Candle;
22
/// Configuration for the [`WilliamsR`] indicator.
#[derive(Debug, Clone)]
pub struct WrParams {
    /// Number of bars in the rolling high/low look-back window.
    pub period: usize,
}

impl Default for WrParams {
    /// Returns parameters with a 14-bar look-back window.
    fn default() -> Self {
        let period = 14;
        WrParams { period }
    }
}
32
33#[derive(Debug, Clone)]
34pub struct WilliamsR {
35    pub params: WrParams,
36}
37
38impl WilliamsR {
39    pub fn new(params: WrParams) -> Self {
40        Self { params }
41    }
42    pub fn with_period(period: usize) -> Self {
43        Self::new(WrParams { period })
44    }
45    fn output_key(&self) -> String {
46        format!("WR_{}", self.params.period)
47    }
48}
49
50impl Indicator for WilliamsR {
51    fn name(&self) -> &'static str {
52        "WilliamsR"
53    }
54    fn required_len(&self) -> usize {
55        self.params.period
56    }
57    fn required_columns(&self) -> &[&'static str] {
58        &["high", "low", "close"]
59    }
60
61    /// Ports `-100 * (highest_high - close) / (highest_high - lowest_low)`.
62    ///
63    /// When `highest_high == lowest_low` Python produces `NaN` via float
64    /// division by zero; the Rust guards this explicitly with a
65    /// `range == 0.0` check.  Both paths produce `NaN` for that bar.
66    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
67        self.check_len(candles)?;
68
69        let n = candles.len();
70        let p = self.params.period;
71        let mut values = vec![f64::NAN; n];
72
73        for i in (p - 1)..n {
74            let window = &candles[(i + 1 - p)..=i];
75            let highest_h = window
76                .iter()
77                .map(|c| c.high)
78                .fold(f64::NEG_INFINITY, f64::max);
79            let lowest_l = window.iter().map(|c| c.low).fold(f64::INFINITY, f64::min);
80            let range = highest_h - lowest_l;
81            values[i] = if range == 0.0 {
82                f64::NAN
83            } else {
84                -100.0 * (highest_h - candles[i].close) / range
85            };
86        }
87
88        Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
89    }
90}
91
92pub fn factory<S: ::std::hash::BuildHasher>(
93    params: &HashMap<String, String, S>,
94) -> Result<Box<dyn Indicator>, IndicatorError> {
95    Ok(Box::new(WilliamsR::new(WrParams {
96        period: param_usize(params, "period", 14)?,
97    })))
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns `(high, low, close)` triples into candles; `open` mirrors
    /// `close`, `time` is the position in the slice and `volume` is 1.
    fn candles(data: &[(f64, f64, f64)]) -> Vec<Candle> {
        let mut out = Vec::with_capacity(data.len());
        for (idx, &(high, low, close)) in data.iter().enumerate() {
            out.push(Candle {
                time: i64::try_from(idx).expect("time index fits i64"),
                open: close,
                high,
                low,
                close,
                volume: 1.0,
            });
        }
        out
    }

    /// Produces `n` candles whose prices climb by one unit per bar.
    fn rising(n: usize) -> Vec<Candle> {
        let mut out = Vec::with_capacity(n);
        for idx in 0..n {
            let base = idx as f64;
            out.push(Candle {
                time: i64::try_from(idx).expect("time index fits i64"),
                open: base,
                high: base + 1.0,
                low: base - 1.0,
                close: base + 0.5,
                volume: 1.0,
            });
        }
        out
    }

    /// Every defined value must lie inside the closed interval [-100, 0].
    #[test]
    fn wr_range_neg100_to_0() {
        let series = rising(20);
        let out = WilliamsR::with_period(14).calculate(&series).unwrap();
        for &v in out.get("WR_14").unwrap() {
            if v.is_nan() {
                continue;
            }
            assert!((-100.0..=0.0).contains(&v), "out of range: {v}");
        }
    }

    /// A close sitting on the window's highest high gives 0.
    #[test]
    fn wr_close_at_high_is_zero() {
        let bars = candles(&[(12.0f64, 8.0, 12.0); 14]);
        let out = WilliamsR::with_period(14).calculate(&bars).unwrap();
        let last = out.get("WR_14").unwrap()[13];
        assert!(last.abs() < 1e-9, "got {}", last);
    }

    /// A close sitting on the window's lowest low gives -100.
    #[test]
    fn wr_close_at_low_is_neg100() {
        let bars = candles(&[(12.0f64, 8.0, 8.0); 14]);
        let out = WilliamsR::with_period(14).calculate(&bars).unwrap();
        let last = out.get("WR_14").unwrap()[13];
        assert!((last + 100.0).abs() < 1e-9, "got {}", last);
    }

    /// The factory works with an empty parameter map.
    #[test]
    fn factory_creates_wr() {
        let indicator = factory(&HashMap::new()).unwrap();
        assert_eq!(indicator.name(), "WilliamsR");
    }
}