1use std::collections::HashMap;
21
22use crate::error::IndicatorError;
23use crate::indicator::{Indicator, IndicatorOutput};
24use crate::registry::param_usize;
25use crate::types::Candle;
26
/// Look-back periods that configure the [`Stochastic`] oscillator.
#[derive(Clone, Debug)]
pub struct StochParams {
    /// Number of candles in the highest-high / lowest-low window used for raw %K.
    pub k_period: usize,
    /// Window of the moving average applied to raw %K; `0` or `1` leaves %K unsmoothed.
    pub smooth_k: usize,
    /// Window of the moving average applied to %K to obtain %D.
    pub d_period: usize,
}
38
39impl Default for StochParams {
40 fn default() -> Self {
41 Self {
42 k_period: 14,
43 smooth_k: 3,
44 d_period: 3,
45 }
46 }
47}
48
49#[derive(Debug, Clone)]
52pub struct Stochastic {
53 pub params: StochParams,
54}
55
56impl Stochastic {
57 pub fn new(params: StochParams) -> Self {
58 Self { params }
59 }
60}
61
62impl Default for Stochastic {
63 fn default() -> Self {
64 Self::new(StochParams::default())
65 }
66}
67
68impl Indicator for Stochastic {
69 fn name(&self) -> &'static str {
70 "Stochastic"
71 }
72
73 fn required_len(&self) -> usize {
74 self.params.k_period + self.params.smooth_k + self.params.d_period - 2
75 }
76
77 fn required_columns(&self) -> &[&'static str] {
78 &["high", "low", "close"]
79 }
80
81 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
82 self.check_len(candles)?;
83
84 let n = candles.len();
85 let kp = self.params.k_period;
86 let sk = self.params.smooth_k;
87 let dp = self.params.d_period;
88
89 let mut raw_k = vec![f64::NAN; n];
91 for i in (kp - 1)..n {
92 let window = &candles[(i + 1 - kp)..=i];
93 let hh = window
94 .iter()
95 .map(|c| c.high)
96 .fold(f64::NEG_INFINITY, f64::max);
97 let ll = window.iter().map(|c| c.low).fold(f64::INFINITY, f64::min);
98 let range = hh - ll;
99 raw_k[i] = if range == 0.0 {
100 f64::NAN
101 } else {
102 100.0 * (candles[i].close - ll) / range
103 };
104 }
105
106 let smooth_k = if sk <= 1 {
108 raw_k.clone()
109 } else {
110 sma_of(&raw_k, sk)
111 };
112
113 let d = sma_of(&smooth_k, dp);
115
116 Ok(IndicatorOutput::from_pairs([
117 ("Stoch_K".to_string(), smooth_k),
118 ("Stoch_D".to_string(), d),
119 ]))
120 }
121}
122
/// Simple moving average over `src` that restarts after every NaN.
///
/// The result has the same length as `src`. Position `i` holds the mean of the
/// `period` values ending at `i` once at least `period` consecutive non-NaN
/// values (including `src[i]`) have been seen; every other position is NaN.
fn sma_of(src: &[f64], period: usize) -> Vec<f64> {
    // Length of the current run of non-NaN inputs.
    let mut run = 0usize;

    src.iter()
        .enumerate()
        .map(|(end, value)| {
            run = if value.is_nan() { 0 } else { run + 1 };
            if value.is_nan() || run < period {
                return f64::NAN;
            }
            // `run >= period` guarantees `end + 1 >= period`, so the start is in range.
            let window = &src[(end + 1 - period)..=end];
            window.iter().sum::<f64>() / period as f64
        })
        .collect()
}
147
148pub fn factory<S: ::std::hash::BuildHasher>(
151 params: &HashMap<String, String, S>,
152) -> Result<Box<dyn Indicator>, IndicatorError> {
153 Ok(Box::new(Stochastic::new(StochParams {
154 k_period: param_usize(params, "k_period", 14)?,
155 smooth_k: param_usize(params, "smooth_k", 3)?,
156 d_period: param_usize(params, "d_period", 3)?,
157 })))
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles from `(high, low, close)` triples.
    ///
    /// `open` mirrors `close`, `volume` is `1.0` and `time` is the triple's index.
    fn make_candles(data: &[(f64, f64, f64)]) -> Vec<Candle> {
        data.iter()
            .enumerate()
            .map(|(i, &(h, l, c))| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: h,
                low: l,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    /// Builds `n` identical candles.
    fn uniform_candles(n: usize, high: f64, low: f64, close: f64) -> Vec<Candle> {
        make_candles(&vec![(high, low, close); n])
    }

    /// Returns the non-NaN values of `series`.
    ///
    /// Fails the test when there are none, so per-value assertions cannot pass
    /// vacuously on an all-NaN series.
    fn defined_values(series: &[f64]) -> Vec<f64> {
        let values: Vec<f64> = series.iter().copied().filter(|v| !v.is_nan()).collect();
        assert!(!values.is_empty(), "series has no non-NaN values");
        values
    }

    #[test]
    fn stoch_insufficient_data() {
        let err = Stochastic::default()
            .calculate(&uniform_candles(5, 12.0, 8.0, 10.0))
            .unwrap_err();
        assert!(matches!(err, IndicatorError::InsufficientData { .. }));
    }

    #[test]
    fn stoch_output_columns_exist() {
        let out = Stochastic::default()
            .calculate(&uniform_candles(30, 12.0, 8.0, 10.0))
            .unwrap();
        assert!(out.get("Stoch_K").is_some());
        assert!(out.get("Stoch_D").is_some());
    }

    #[test]
    fn stoch_known_value_midpoint() {
        // Close sits exactly halfway between low and high, so both lines are 50.
        let out = Stochastic::new(StochParams {
            k_period: 5,
            smooth_k: 3,
            d_period: 3,
        })
        .calculate(&uniform_candles(20, 12.0, 8.0, 10.0))
        .unwrap();
        let k = out.get("Stoch_K").unwrap();
        let d = out.get("Stoch_D").unwrap();
        let last_k = k.iter().rev().find(|v| !v.is_nan()).copied().unwrap();
        let last_d = d.iter().rev().find(|v| !v.is_nan()).copied().unwrap();
        assert!(
            (last_k - 50.0).abs() < 1e-9,
            "K expected 50.0, got {last_k}"
        );
        assert!(
            (last_d - 50.0).abs() < 1e-9,
            "D expected 50.0, got {last_d}"
        );
    }

    #[test]
    fn stoch_close_at_high_is_100() {
        let out = Stochastic::new(StochParams {
            k_period: 5,
            smooth_k: 1,
            d_period: 1,
        })
        .calculate(&uniform_candles(10, 12.0, 8.0, 12.0))
        .unwrap();
        for v in defined_values(out.get("Stoch_K").unwrap()) {
            assert!((v - 100.0).abs() < 1e-9, "expected 100.0, got {v}");
        }
    }

    #[test]
    fn stoch_close_at_low_is_0() {
        let out = Stochastic::new(StochParams {
            k_period: 5,
            smooth_k: 1,
            d_period: 1,
        })
        .calculate(&uniform_candles(10, 12.0, 8.0, 8.0))
        .unwrap();
        for v in defined_values(out.get("Stoch_K").unwrap()) {
            assert!(v.abs() < 1e-9, "expected 0.0, got {v}");
        }
    }

    #[test]
    fn stoch_range_0_to_100() {
        // Fifteen rising candles followed by ten falling ones.
        let data: Vec<(f64, f64, f64)> = (0..15)
            .chain((0..10).rev())
            .map(|i: i32| {
                let f = f64::from(i);
                (f + 1.0, f - 1.0, f)
            })
            .collect();
        let out = Stochastic::default()
            .calculate(&make_candles(&data))
            .unwrap();
        for v in defined_values(out.get("Stoch_K").unwrap()) {
            assert!((0.0..=100.0).contains(&v), "K out of range: {v}");
        }
        for v in defined_values(out.get("Stoch_D").unwrap()) {
            assert!((0.0..=100.0).contains(&v), "D out of range: {v}");
        }
    }

    #[test]
    fn stoch_no_smoothing_fast_stochastic() {
        let out = Stochastic::new(StochParams {
            k_period: 3,
            smooth_k: 1,
            d_period: 1,
        })
        .calculate(&uniform_candles(10, 10.0, 0.0, 6.0))
        .unwrap();
        for v in defined_values(out.get("Stoch_K").unwrap()) {
            assert!((v - 60.0).abs() < 1e-9, "expected 60.0, got {v}");
        }
    }

    #[test]
    fn stoch_flat_range_is_nan() {
        // High equals low in every window, so the %K ratio is undefined.
        let out = Stochastic::new(StochParams {
            k_period: 5,
            smooth_k: 1,
            d_period: 1,
        })
        .calculate(&uniform_candles(10, 10.0, 10.0, 10.0))
        .unwrap();
        assert!(out.get("Stoch_K").unwrap().iter().all(|v| v.is_nan()));
        assert!(out.get("Stoch_D").unwrap().iter().all(|v| v.is_nan()));
    }

    #[test]
    fn sma_of_restarts_after_nan() {
        let out = sma_of(&[1.0, 3.0, f64::NAN, 5.0, 7.0], 2);
        assert_eq!(out.len(), 5);
        assert!(out[0].is_nan());
        assert!((out[1] - 2.0).abs() < 1e-12);
        assert!(out[2].is_nan());
        // Only one non-NaN value since the gap, so the window is not yet full.
        assert!(out[3].is_nan());
        assert!((out[4] - 6.0).abs() < 1e-12);
    }

    #[test]
    fn factory_creates_stochastic() {
        let ind = factory(&HashMap::new()).unwrap();
        assert_eq!(ind.name(), "Stochastic");
    }
}