indicators/momentum/
stochastic_rsi.rs1use std::collections::HashMap;
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::momentum::rsi::{Rsi, RsiParams};
20use crate::registry::param_usize;
21use crate::types::Candle;
22
/// Lookback and smoothing lengths for [`StochasticRsi`].
///
/// The `Default` impl yields the classic 14 / 14 / 3 / 3 configuration.
/// `rsi_period`, `stoch_period` and `d_period` should be at least 1.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StochRsiParams {
    /// Lookback handed to the underlying RSI (as `RsiParams::period`).
    pub rsi_period: usize,
    /// Number of trailing RSI values whose min/max bound the raw %K.
    pub stoch_period: usize,
    /// SMA length applied to the raw %K; 0 or 1 leaves %K unsmoothed.
    pub k_smooth: usize,
    /// SMA length applied to the (smoothed) %K to produce the %D line.
    pub d_period: usize,
}

impl Default for StochRsiParams {
    /// RSI 14, stochastic window 14, %K smoothing 3, %D 3.
    fn default() -> Self {
        Self {
            rsi_period: 14,
            stoch_period: 14,
            k_smooth: 3,
            d_period: 3,
        }
    }
}
47
48#[derive(Debug, Clone)]
51pub struct StochasticRsi {
52 pub params: StochRsiParams,
53}
54
55impl StochasticRsi {
56 pub fn new(params: StochRsiParams) -> Self {
57 Self { params }
58 }
59}
60
61impl Default for StochasticRsi {
62 fn default() -> Self {
63 Self::new(StochRsiParams::default())
64 }
65}
66
67impl Indicator for StochasticRsi {
68 fn name(&self) -> &'static str {
69 "StochasticRSI"
70 }
71
72 fn required_len(&self) -> usize {
73 self.params.rsi_period
76 + 1
77 + self.params.stoch_period
78 + self.params.k_smooth
79 + self.params.d_period
80 - 2
81 }
82
83 fn required_columns(&self) -> &[&'static str] {
84 &["close"]
85 }
86
87 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
88 self.check_len(candles)?;
89
90 let n = candles.len();
91 let rsi_p = self.params.rsi_period;
92 let stoch_p = self.params.stoch_period;
93 let ks = self.params.k_smooth;
94 let dp = self.params.d_period;
95
96 let rsi_out = Rsi::new(RsiParams {
98 period: rsi_p,
99 ..Default::default()
100 })
101 .calculate(candles)?;
102 let rsi_key = format!("RSI_{rsi_p}");
103 let rsi: &[f64] = rsi_out
104 .get(&rsi_key)
105 .ok_or_else(|| IndicatorError::InvalidParam("RSI output missing".into()))?;
106
107 let mut raw_k = vec![f64::NAN; n];
109 for i in (stoch_p - 1)..n {
110 let window = &rsi[(i + 1 - stoch_p)..=i];
112 if window.iter().any(|v| v.is_nan()) {
113 continue;
114 }
115 let min_r = window.iter().copied().fold(f64::INFINITY, f64::min);
116 let max_r = window.iter().copied().fold(f64::NEG_INFINITY, f64::max);
117 let range = max_r - min_r;
118 raw_k[i] = if range == 0.0 {
119 50.0
120 }
121 else {
123 100.0 * (rsi[i] - min_r) / range
124 };
125 }
126
127 let smooth_k = if ks <= 1 {
129 raw_k.clone()
130 } else {
131 sma_of(&raw_k, ks)
132 };
133
134 let d = sma_of(&smooth_k, dp);
136
137 Ok(IndicatorOutput::from_pairs([
138 ("StochRSI_K".to_string(), smooth_k),
139 ("StochRSI_D".to_string(), d),
140 ]))
141 }
142}
143
/// Trailing simple moving average that restarts after every NaN.
///
/// The result has the same length as `src`. Entry `i` is the mean of
/// `src[i + 1 - period..=i]` once at least `period` consecutive non-NaN
/// values end at index `i`; every other entry (including each NaN input)
/// is NaN. A `period` of 0 yields NaN everywhere (empty window, 0 / 0).
fn sma_of(src: &[f64], period: usize) -> Vec<f64> {
    // Length of the current run of non-NaN values ending at the cursor.
    let mut run = 0usize;
    src.iter()
        .enumerate()
        .map(|(idx, &value)| {
            if value.is_nan() {
                run = 0;
                return f64::NAN;
            }
            run += 1;
            if run < period {
                return f64::NAN;
            }
            let window = &src[idx + 1 - period..=idx];
            window.iter().sum::<f64>() / period as f64
        })
        .collect()
}
161
162pub fn factory<S: ::std::hash::BuildHasher>(
165 params: &HashMap<String, String, S>,
166) -> Result<Box<dyn Indicator>, IndicatorError> {
167 Ok(Box::new(StochasticRsi::new(StochRsiParams {
168 rsi_period: param_usize(params, "rsi_period", 14)?,
169 stoch_period: param_usize(params, "stoch_period", 14)?,
170 k_smooth: param_usize(params, "k_smooth", 3)?,
171 d_period: param_usize(params, "d_period", 3)?,
172 })))
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns closing prices into flat candles (open = high = low = close),
    /// timestamped by position and carrying a volume of 1.0.
    fn make_candles(closes: &[f64]) -> Vec<Candle> {
        let mut candles = Vec::with_capacity(closes.len());
        for (idx, &price) in closes.iter().enumerate() {
            candles.push(Candle {
                time: i64::try_from(idx).expect("time index fits i64"),
                open: price,
                high: price,
                low: price,
                close: price,
                volume: 1.0,
            });
        }
        candles
    }

    /// `len` closes oscillating around 100.0: `100 + sin(i * step) * amplitude`.
    fn sine_closes(len: usize, step: f64, amplitude: f64) -> Vec<f64> {
        (0..len)
            .map(|i| 100.0 + (i as f64 * step).sin() * amplitude)
            .collect()
    }

    /// Counts the entries of `values` that are not NaN.
    fn defined(values: &[f64]) -> usize {
        values.iter().filter(|v| !v.is_nan()).count()
    }

    /// Runs the default-configured indicator over `closes`, expecting success.
    fn run_default(closes: &[f64]) -> IndicatorOutput {
        StochasticRsi::default()
            .calculate(&make_candles(closes))
            .unwrap()
    }

    #[test]
    fn stochrsi_insufficient_data() {
        let candles = make_candles(&[1.0; 10]);
        let err = StochasticRsi::default().calculate(&candles).unwrap_err();
        assert!(matches!(err, IndicatorError::InsufficientData { .. }));
    }

    #[test]
    fn stochrsi_output_columns_exist() {
        let needed = StochasticRsi::default().required_len();
        let out = run_default(&sine_closes(needed + 5, 0.4, 5.0));
        assert!(out.get("StochRSI_K").is_some());
        assert!(out.get("StochRSI_D").is_some());
    }

    #[test]
    fn stochrsi_range_0_to_100() {
        let needed = StochasticRsi::default().required_len();
        let out = run_default(&sine_closes(needed + 20, 0.25, 8.0));
        let k = out.get("StochRSI_K").unwrap();
        for v in k.iter().copied().filter(|v| !v.is_nan()) {
            assert!((0.0..=100.0).contains(&v), "K out of range: {v}");
        }
        let d = out.get("StochRSI_D").unwrap();
        for v in d.iter().copied().filter(|v| !v.is_nan()) {
            assert!((0.0..=100.0).contains(&v), "D out of range: {v}");
        }
    }

    #[test]
    fn stochrsi_constant_prices_neutral() {
        let needed = StochasticRsi::default().required_len();
        let out = run_default(&vec![100.0_f64; needed + 5]);
        let k = out.get("StochRSI_K").unwrap();
        for v in k.iter().copied().filter(|v| !v.is_nan()) {
            assert!((v - 50.0).abs() < 1e-9, "expected 50.0 (neutral), got {v}");
        }
    }

    #[test]
    fn stochrsi_d_lags_k() {
        let needed = StochasticRsi::default().required_len();
        let out = run_default(&sine_closes(needed + 10, 0.5, 5.0));
        let k_count = defined(out.get("StochRSI_K").unwrap());
        let d_count = defined(out.get("StochRSI_D").unwrap());
        assert!(d_count <= k_count, "D should have ≤ non-NaN values than K");
    }

    #[test]
    fn factory_creates_stochrsi() {
        let indicator = factory(&HashMap::new()).unwrap();
        assert_eq!(indicator.name(), "StochasticRSI");
    }
}