Skip to main content

indicators/momentum/
stochastic_rsi.rs

//! Stochastic RSI oscillator.
//!
//! Python source: `indicators/momentum/stochastic_rsi.py :: class StochasticRSI`
//!
//! # Algorithm
//!
//! 1. Compute the RSI series (Wilder's method) with `rsi_period`.
//! 2. Apply the Stochastic formula to the RSI values over a rolling window of
//!    `stoch_period` bars:
//!    `%K_raw[i] = 100 * (rsi[i] - min_rsi) / (max_rsi - min_rsi)`,
//!    where `min_rsi` / `max_rsi` are the extremes of that window. When the
//!    window is flat (`max_rsi == min_rsi`) `%K_raw[i]` is set to the neutral
//!    value 50.
//! 3. Smooth %K with an SMA over `k_smooth` bars (skipped when `k_smooth <= 1`).
//! 4. %D = SMA of the smoothed %K over `d_period` bars.
//!
//! Output columns: `"StochRSI_K"`, `"StochRSI_D"`.
14
15use std::collections::HashMap;
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::momentum::rsi::{Rsi, RsiParams};
20use crate::registry::param_usize;
21use crate::types::Candle;
22
23// ── Params ────────────────────────────────────────────────────────────────────
24
/// Tunable periods for the [`StochasticRsi`] indicator.
///
/// All fields are bar counts. `PartialEq`/`Eq` are derived so two
/// configurations can be compared directly (e.g. in tests or caches).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StochRsiParams {
    /// Look-back period of the underlying RSI. Default: 14.
    pub rsi_period: usize,
    /// Length of the rolling window over RSI values used by the stochastic
    /// step. Default: 14.
    pub stoch_period: usize,
    /// SMA length applied to raw %K; values `<= 1` disable smoothing.
    /// Default: 3.
    pub k_smooth: usize,
    /// SMA length applied to the smoothed %K to obtain %D. Default: 3.
    pub d_period: usize,
}
36
37impl Default for StochRsiParams {
38    fn default() -> Self {
39        Self {
40            rsi_period: 14,
41            stoch_period: 14,
42            k_smooth: 3,
43            d_period: 3,
44        }
45    }
46}
47
48// ── Indicator struct ──────────────────────────────────────────────────────────
49
50#[derive(Debug, Clone)]
51pub struct StochasticRsi {
52    pub params: StochRsiParams,
53}
54
55impl StochasticRsi {
56    pub fn new(params: StochRsiParams) -> Self {
57        Self { params }
58    }
59}
60
61impl Default for StochasticRsi {
62    fn default() -> Self {
63        Self::new(StochRsiParams::default())
64    }
65}
66
67impl Indicator for StochasticRsi {
68    fn name(&self) -> &'static str {
69        "StochasticRSI"
70    }
71
72    fn required_len(&self) -> usize {
73        // RSI needs rsi_period+1 bars; stochastic then needs stoch_period RSI values;
74        // then k_smooth and d_period smoothing on top.
75        self.params.rsi_period
76            + 1
77            + self.params.stoch_period
78            + self.params.k_smooth
79            + self.params.d_period
80            - 2
81    }
82
83    fn required_columns(&self) -> &[&'static str] {
84        &["close"]
85    }
86
87    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
88        self.check_len(candles)?;
89
90        let n = candles.len();
91        let rsi_p = self.params.rsi_period;
92        let stoch_p = self.params.stoch_period;
93        let ks = self.params.k_smooth;
94        let dp = self.params.d_period;
95
96        // ── Step 1: RSI series ────────────────────────────────────────────────
97        let rsi_out = Rsi::new(RsiParams {
98            period: rsi_p,
99            ..Default::default()
100        })
101        .calculate(candles)?;
102        let rsi_key = format!("RSI_{rsi_p}");
103        let rsi: &[f64] = rsi_out
104            .get(&rsi_key)
105            .ok_or_else(|| IndicatorError::InvalidParam("RSI output missing".into()))?;
106
107        // ── Step 2: Stochastic of RSI ─────────────────────────────────────────
108        let mut raw_k = vec![f64::NAN; n];
109        for i in (stoch_p - 1)..n {
110            // Window must be fully non-NaN.
111            let window = &rsi[(i + 1 - stoch_p)..=i];
112            if window.iter().any(|v| v.is_nan()) {
113                continue;
114            }
115            let min_r = window.iter().copied().fold(f64::INFINITY, f64::min);
116            let max_r = window.iter().copied().fold(f64::NEG_INFINITY, f64::max);
117            let range = max_r - min_r;
118            raw_k[i] = if range == 0.0 {
119                50.0
120            }
121            // flat RSI → neutral %K
122            else {
123                100.0 * (rsi[i] - min_r) / range
124            };
125        }
126
127        // ── Step 3: smooth %K ─────────────────────────────────────────────────
128        let smooth_k = if ks <= 1 {
129            raw_k.clone()
130        } else {
131            sma_of(&raw_k, ks)
132        };
133
134        // ── Step 4: %D ────────────────────────────────────────────────────────
135        let d = sma_of(&smooth_k, dp);
136
137        Ok(IndicatorOutput::from_pairs([
138            ("StochRSI_K".to_string(), smooth_k),
139            ("StochRSI_D".to_string(), d),
140        ]))
141    }
142}
143
/// Simple moving average of `src` over `period` samples, NaN-aware.
///
/// The result has the same length as `src`. Position `i` holds the mean of
/// `src[i + 1 - period ..= i]` once at least `period` consecutive non-NaN
/// samples end at `i`; every other position is `NaN`. A `NaN` in `src` resets
/// the run, so averages never straddle a gap.
fn sma_of(src: &[f64], period: usize) -> Vec<f64> {
    // Length of the current streak of non-NaN samples ending at the cursor.
    let mut run = 0usize;
    src.iter()
        .enumerate()
        .map(|(idx, sample)| {
            if sample.is_nan() {
                run = 0;
                return f64::NAN;
            }
            run += 1;
            if run < period {
                return f64::NAN;
            }
            // `run >= period` guarantees `idx + 1 >= period`, so this cannot
            // underflow.
            let window = &src[(idx + 1 - period)..=idx];
            window.iter().sum::<f64>() / period as f64
        })
        .collect()
}
161
162// ── Registry factory ──────────────────────────────────────────────────────────
163
164pub fn factory<S: ::std::hash::BuildHasher>(
165    params: &HashMap<String, String, S>,
166) -> Result<Box<dyn Indicator>, IndicatorError> {
167    Ok(Box::new(StochasticRsi::new(StochRsiParams {
168        rsi_period: param_usize(params, "rsi_period", 14)?,
169        stoch_period: param_usize(params, "stoch_period", 14)?,
170        k_smooth: param_usize(params, "k_smooth", 3)?,
171        d_period: param_usize(params, "d_period", 3)?,
172    })))
173}
174
175// ── Tests ─────────────────────────────────────────────────────────────────────
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds flat candles (open = high = low = close) from a close series,
    /// using each element's index as its timestamp.
    fn make_candles(closes: &[f64]) -> Vec<Candle> {
        let mut candles = Vec::with_capacity(closes.len());
        for (idx, &price) in closes.iter().enumerate() {
            candles.push(Candle {
                time: i64::try_from(idx).expect("time index fits i64"),
                open: price,
                high: price,
                low: price,
                close: price,
                volume: 1.0,
            });
        }
        candles
    }

    /// Sine-wave closes centred on 100: `100 + sin(i * step) * amplitude`.
    fn sine_prices(len: usize, step: f64, amplitude: f64) -> Vec<f64> {
        (0..len)
            .map(|i| 100.0 + (i as f64 * step).sin() * amplitude)
            .collect()
    }

    /// Runs the default-configured indicator over `prices`, panicking on error.
    fn run_default(prices: &[f64]) -> IndicatorOutput {
        StochasticRsi::default()
            .calculate(&make_candles(prices))
            .unwrap()
    }

    /// Number of non-NaN entries in `series`.
    fn finite_count(series: &[f64]) -> usize {
        series.iter().filter(|v| !v.is_nan()).count()
    }

    #[test]
    fn stochrsi_insufficient_data() {
        // Ten bars is below the default configuration's required length.
        let candles = make_candles(&[1.0; 10]);
        let result = StochasticRsi::default().calculate(&candles);
        assert!(matches!(
            result,
            Err(IndicatorError::InsufficientData { .. })
        ));
    }

    #[test]
    fn stochrsi_output_columns_exist() {
        let len = StochasticRsi::default().required_len() + 5;
        let out = run_default(&sine_prices(len, 0.4, 5.0));
        assert!(out.get("StochRSI_K").is_some());
        assert!(out.get("StochRSI_D").is_some());
    }

    #[test]
    fn stochrsi_range_0_to_100() {
        let len = StochasticRsi::default().required_len() + 20;
        let out = run_default(&sine_prices(len, 0.25, 8.0));

        let k = out.get("StochRSI_K").unwrap();
        for v in k.iter().copied().filter(|v| !v.is_nan()) {
            assert!((0.0..=100.0).contains(&v), "K out of range: {v}");
        }

        let d = out.get("StochRSI_D").unwrap();
        for v in d.iter().copied().filter(|v| !v.is_nan()) {
            assert!((0.0..=100.0).contains(&v), "D out of range: {v}");
        }
    }

    #[test]
    fn stochrsi_constant_prices_neutral() {
        // With constant closes every computed %K is expected to sit at the
        // neutral value 50 (flat-RSI guard); NaN warm-up entries are skipped.
        let len = StochasticRsi::default().required_len() + 5;
        let flat = vec![100.0_f64; len];
        let out = run_default(&flat);
        let k = out.get("StochRSI_K").unwrap();
        for v in k.iter().copied().filter(|v| !v.is_nan()) {
            assert!((v - 50.0).abs() < 1e-9, "expected 50.0 (neutral), got {v}");
        }
    }

    #[test]
    fn stochrsi_d_lags_k() {
        // %D is an SMA of %K, so it can never have more non-NaN values than %K.
        let len = StochasticRsi::default().required_len() + 10;
        let out = run_default(&sine_prices(len, 0.5, 5.0));
        let k_count = finite_count(out.get("StochRSI_K").unwrap());
        let d_count = finite_count(out.get("StochRSI_D").unwrap());
        assert!(d_count <= k_count, "D should have ≤ non-NaN values than K");
    }

    #[test]
    fn factory_creates_stochrsi() {
        // An empty parameter map must yield the default-configured indicator.
        let no_params: HashMap<String, String> = HashMap::new();
        let ind = factory(&no_params).unwrap();
        assert_eq!(ind.name(), "StochasticRSI");
    }
}