indicators/momentum/
stochastic_rsi.rs1use std::collections::HashMap;
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::momentum::rsi::{Rsi, RsiParams};
20use crate::registry::param_usize;
21use crate::types::Candle;
22
/// Tunable periods for the Stochastic RSI indicator.
#[derive(Debug, Clone)]
pub struct StochRsiParams {
    /// Period handed to the underlying RSI calculation.
    pub rsi_period: usize,
    /// Length of the rolling window over RSI values whose min/max define the
    /// stochastic range.
    pub stoch_period: usize,
    /// Period of the simple moving average applied to the raw stochastic to
    /// produce %K; values of `1` or less leave the raw series unsmoothed.
    pub k_smooth: usize,
    /// Period of the simple moving average applied to %K to produce %D.
    pub d_period: usize,
}
36
37impl Default for StochRsiParams {
38 fn default() -> Self {
39 Self {
40 rsi_period: 14,
41 stoch_period: 14,
42 k_smooth: 3,
43 d_period: 3,
44 }
45 }
46}
47
48#[derive(Debug, Clone)]
51pub struct StochasticRsi {
52 pub params: StochRsiParams,
53}
54
55impl StochasticRsi {
56 pub fn new(params: StochRsiParams) -> Self {
57 Self { params }
58 }
59}
60
61impl Default for StochasticRsi {
62 fn default() -> Self {
63 Self::new(StochRsiParams::default())
64 }
65}
66
67impl Indicator for StochasticRsi {
68 fn name(&self) -> &'static str {
69 "StochasticRSI"
70 }
71
72 fn required_len(&self) -> usize {
73 self.params.rsi_period
76 + 1
77 + self.params.stoch_period
78 + self.params.k_smooth
79 + self.params.d_period
80 - 2
81 }
82
83 fn required_columns(&self) -> &[&'static str] {
84 &["close"]
85 }
86
87 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
88 self.check_len(candles)?;
89
90 let n = candles.len();
91 let rsi_p = self.params.rsi_period;
92 let stoch_p = self.params.stoch_period;
93 let ks = self.params.k_smooth;
94 let dp = self.params.d_period;
95
96 let rsi_out = Rsi::new(RsiParams {
98 period: rsi_p,
99 ..Default::default()
100 })
101 .calculate(candles)?;
102 let rsi_key = format!("RSI_{rsi_p}");
103 let rsi: &[f64] = rsi_out
104 .get(&rsi_key)
105 .ok_or_else(|| IndicatorError::InvalidParam("RSI output missing".into()))?;
106
107 let mut raw_k = vec![f64::NAN; n];
109 for i in (stoch_p - 1)..n {
110 let window = &rsi[(i + 1 - stoch_p)..=i];
112 if window.iter().any(|v| v.is_nan()) {
113 continue;
114 }
115 let min_r = window.iter().copied().fold(f64::INFINITY, f64::min);
116 let max_r = window.iter().copied().fold(f64::NEG_INFINITY, f64::max);
117 let range = max_r - min_r;
118 raw_k[i] = if range == 0.0 {
119 50.0
120 }
121 else {
123 100.0 * (rsi[i] - min_r) / range
124 };
125 }
126
127 let smooth_k = if ks <= 1 {
129 raw_k.clone()
130 } else {
131 sma_of(&raw_k, ks)
132 };
133
134 let d = sma_of(&smooth_k, dp);
136
137 Ok(IndicatorOutput::from_pairs([
138 ("StochRSI_K".to_string(), smooth_k),
139 ("StochRSI_D".to_string(), d),
140 ]))
141 }
142}
143
/// Simple moving average over `src` that tolerates `NaN` gaps.
///
/// The output has the same length as `src`. A position receives the mean of
/// the last `period` values only once at least `period` consecutive non-NaN
/// values end at that position; every other position is `NaN`. A `NaN` input
/// resets the run, so averages never mix values from both sides of a gap.
fn sma_of(src: &[f64], period: usize) -> Vec<f64> {
    let mut averages = vec![f64::NAN; src.len()];
    let mut run = 0usize;
    for (idx, value) in src.iter().enumerate() {
        if value.is_nan() {
            run = 0;
            continue;
        }
        run += 1;
        if run < period {
            continue;
        }
        let window = &src[(idx + 1 - period)..=idx];
        averages[idx] = window.iter().sum::<f64>() / period as f64;
    }
    averages
}
161
162pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
165 Ok(Box::new(StochasticRsi::new(StochRsiParams {
166 rsi_period: param_usize(params, "rsi_period", 14)?,
167 stoch_period: param_usize(params, "stoch_period", 14)?,
168 k_smooth: param_usize(params, "k_smooth", 3)?,
169 d_period: param_usize(params, "d_period", 3)?,
170 })))
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds flat candles (open = high = low = close) from a list of closes,
    /// using the position in the list as the timestamp.
    fn make_candles(closes: &[f64]) -> Vec<Candle> {
        let mut candles = Vec::with_capacity(closes.len());
        for (idx, &price) in closes.iter().enumerate() {
            candles.push(Candle {
                time: i64::try_from(idx).expect("time index fits i64"),
                open: price,
                high: price,
                low: price,
                close: price,
                volume: 1.0,
            });
        }
        candles
    }

    /// Sine-wave closes centred on 100.0, `extra` bars longer than the
    /// default indicator's required length.
    fn sine_closes(extra: usize, step: f64, amplitude: f64) -> Vec<f64> {
        let len = StochasticRsi::default().required_len() + extra;
        (0..len)
            .map(|i| 100.0 + (i as f64 * step).sin() * amplitude)
            .collect()
    }

    /// Runs the default indicator over the given closes, panicking on error.
    fn run_default(closes: &[f64]) -> IndicatorOutput {
        StochasticRsi::default()
            .calculate(&make_candles(closes))
            .unwrap()
    }

    /// Number of non-NaN entries in a series.
    fn valid_count(series: &[f64]) -> usize {
        series.iter().filter(|v| !v.is_nan()).count()
    }

    #[test]
    fn stochrsi_insufficient_data() {
        let candles = make_candles(&[1.0; 10]);
        let err = StochasticRsi::default().calculate(&candles).unwrap_err();
        assert!(matches!(err, IndicatorError::InsufficientData { .. }));
    }

    #[test]
    fn stochrsi_output_columns_exist() {
        let out = run_default(&sine_closes(5, 0.4, 5.0));
        assert!(out.get("StochRSI_K").is_some());
        assert!(out.get("StochRSI_D").is_some());
    }

    #[test]
    fn stochrsi_range_0_to_100() {
        let out = run_default(&sine_closes(20, 0.25, 8.0));
        let bounds = 0.0..=100.0;
        for &v in out.get("StochRSI_K").unwrap().iter() {
            if v.is_nan() {
                continue;
            }
            assert!(bounds.contains(&v), "K out of range: {v}");
        }
        for &v in out.get("StochRSI_D").unwrap().iter() {
            if v.is_nan() {
                continue;
            }
            assert!(bounds.contains(&v), "D out of range: {v}");
        }
    }

    #[test]
    fn stochrsi_constant_prices_neutral() {
        let len = StochasticRsi::default().required_len() + 5;
        let out = run_default(&vec![100.0_f64; len]);
        let k = out.get("StochRSI_K").unwrap();
        for &v in k.iter() {
            if v.is_nan() {
                continue;
            }
            assert!((v - 50.0).abs() < 1e-9, "expected 50.0 (neutral), got {v}");
        }
    }

    #[test]
    fn stochrsi_d_lags_k() {
        let out = run_default(&sine_closes(10, 0.5, 5.0));
        let k_count = valid_count(out.get("StochRSI_K").unwrap());
        let d_count = valid_count(out.get("StochRSI_D").unwrap());
        assert!(d_count <= k_count, "D should have ≤ non-NaN values than K");
    }

    #[test]
    fn factory_creates_stochrsi() {
        let no_overrides = HashMap::new();
        let ind = factory(&no_overrides).unwrap();
        assert_eq!(ind.name(), "StochasticRSI");
    }
}