1use std::collections::HashMap;
21
22use crate::error::IndicatorError;
23use crate::indicator::{Indicator, IndicatorOutput};
24use crate::registry::param_usize;
25use crate::types::Candle;
26
/// Configuration for the [`Stochastic`] oscillator.
///
/// All three values are bar counts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StochParams {
    /// Look-back window used to find the highest high and lowest low for raw %K.
    pub k_period: usize,
    /// Length of the simple moving average applied to raw %K (0 or 1 = no smoothing).
    pub smooth_k: usize,
    /// Length of the simple moving average of the (smoothed) %K that forms %D.
    pub d_period: usize,
}
38
39impl Default for StochParams {
40 fn default() -> Self {
41 Self {
42 k_period: 14,
43 smooth_k: 3,
44 d_period: 3,
45 }
46 }
47}
48
49#[derive(Debug, Clone)]
52pub struct Stochastic {
53 pub params: StochParams,
54}
55
56impl Stochastic {
57 pub fn new(params: StochParams) -> Self {
58 Self { params }
59 }
60 pub fn default() -> Self {
61 Self::new(StochParams::default())
62 }
63}
64
65impl Indicator for Stochastic {
66 fn name(&self) -> &str {
67 "Stochastic"
68 }
69
70 fn required_len(&self) -> usize {
71 self.params.k_period + self.params.smooth_k + self.params.d_period - 2
72 }
73
74 fn required_columns(&self) -> &[&'static str] {
75 &["high", "low", "close"]
76 }
77
78 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
79 self.check_len(candles)?;
80
81 let n = candles.len();
82 let kp = self.params.k_period;
83 let sk = self.params.smooth_k;
84 let dp = self.params.d_period;
85
86 let mut raw_k = vec![f64::NAN; n];
88 for i in (kp - 1)..n {
89 let window = &candles[(i + 1 - kp)..=i];
90 let hh = window.iter().map(|c| c.high).fold(f64::NEG_INFINITY, f64::max);
91 let ll = window.iter().map(|c| c.low).fold(f64::INFINITY, f64::min);
92 let range = hh - ll;
93 raw_k[i] = if range == 0.0 {
94 f64::NAN
95 } else {
96 100.0 * (candles[i].close - ll) / range
97 };
98 }
99
100 let smooth_k = if sk <= 1 {
102 raw_k.clone()
103 } else {
104 sma_of(&raw_k, sk)
105 };
106
107 let d = sma_of(&smooth_k, dp);
109
110 Ok(IndicatorOutput::from_pairs([
111 ("Stoch_K".to_string(), smooth_k),
112 ("Stoch_D".to_string(), d),
113 ]))
114 }
115}
116
/// Simple moving average of `src` over `period` values that restarts after gaps.
///
/// Position `idx` receives the mean of `src[idx + 1 - period ..= idx]` only when
/// the last `period` inputs are all non-NaN. Any NaN input resets the run, so
/// outputs stay NaN until `period` fresh values have accumulated again. The
/// result has the same length as `src`; with `period == 0` every output is NaN
/// (an empty sum divided by zero).
fn sma_of(src: &[f64], period: usize) -> Vec<f64> {
    let mut run = 0usize;
    src.iter()
        .enumerate()
        .map(|(idx, &value)| {
            if value.is_nan() {
                run = 0;
                return f64::NAN;
            }
            run += 1;
            if run < period {
                return f64::NAN;
            }
            let start = idx + 1 - period;
            src[start..=idx].iter().sum::<f64>() / period as f64
        })
        .collect()
}
137
138pub fn factory(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
141 Ok(Box::new(Stochastic::new(StochParams {
142 k_period: param_usize(params, "k_period", 14)?,
143 smooth_k: param_usize(params, "smooth_k", 3)?,
144 d_period: param_usize(params, "d_period", 3)?,
145 })))
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns `(high, low, close)` triples into candles. Each candle's `time` is its
    /// position in `data`, `open` equals `close`, and `volume` is 1.
    fn make_candles(data: &[(f64, f64, f64)]) -> Vec<Candle> {
        let mut candles = Vec::with_capacity(data.len());
        for (idx, &(high, low, close)) in data.iter().enumerate() {
            candles.push(Candle {
                time: idx as i64,
                open: close,
                high,
                low,
                close,
                volume: 1.0,
            });
        }
        candles
    }

    /// Returns `n` copies of the same `(high, low, close)` candle.
    fn uniform_candles(n: usize, high: f64, low: f64, close: f64) -> Vec<Candle> {
        let rows = vec![(high, low, close); n];
        make_candles(&rows)
    }

    #[test]
    fn stoch_insufficient_data() {
        // The default 14/3/3 configuration needs far more than five candles.
        let candles = uniform_candles(5, 12.0, 8.0, 10.0);
        let result = Stochastic::default().calculate(&candles);
        assert!(matches!(
            result.unwrap_err(),
            IndicatorError::InsufficientData { .. }
        ));
    }

    #[test]
    fn stoch_output_columns_exist() {
        let candles = uniform_candles(30, 12.0, 8.0, 10.0);
        let out = Stochastic::default().calculate(&candles).unwrap();
        for column in ["Stoch_K", "Stoch_D"] {
            assert!(out.get(column).is_some());
        }
    }

    #[test]
    fn stoch_known_value_midpoint() {
        // The close sits exactly halfway between a constant high of 12 and low of 8.
        let params = StochParams { k_period: 5, smooth_k: 3, d_period: 3 };
        let candles = uniform_candles(20, 12.0, 8.0, 10.0);
        let out = Stochastic::new(params).calculate(&candles).unwrap();
        let last_k = out
            .get("Stoch_K")
            .unwrap()
            .iter()
            .rev()
            .copied()
            .find(|v| !v.is_nan())
            .unwrap();
        let last_d = out
            .get("Stoch_D")
            .unwrap()
            .iter()
            .rev()
            .copied()
            .find(|v| !v.is_nan())
            .unwrap();
        assert!((last_k - 50.0).abs() < 1e-9, "K expected 50.0, got {last_k}");
        assert!((last_d - 50.0).abs() < 1e-9, "D expected 50.0, got {last_d}");
    }

    #[test]
    fn stoch_close_at_high_is_100() {
        let params = StochParams { k_period: 5, smooth_k: 1, d_period: 1 };
        let out = Stochastic::new(params)
            .calculate(&uniform_candles(10, 12.0, 8.0, 12.0))
            .unwrap();
        out.get("Stoch_K")
            .unwrap()
            .iter()
            .copied()
            .filter(|v| !v.is_nan())
            .for_each(|v| assert!((v - 100.0).abs() < 1e-9, "expected 100.0, got {v}"));
    }

    #[test]
    fn stoch_close_at_low_is_0() {
        let params = StochParams { k_period: 5, smooth_k: 1, d_period: 1 };
        let out = Stochastic::new(params)
            .calculate(&uniform_candles(10, 12.0, 8.0, 8.0))
            .unwrap();
        out.get("Stoch_K")
            .unwrap()
            .iter()
            .copied()
            .filter(|v| !v.is_nan())
            .for_each(|v| assert!(v.abs() < 1e-9, "expected 0.0, got {v}"));
    }

    #[test]
    fn stoch_range_0_to_100() {
        // A rise for 15 bars followed by a fall over 10 bars, each bar spanning +/- 1.
        let data: Vec<(f64, f64, f64)> = (0..15)
            .chain((0..10).rev())
            .map(|i| {
                let f = i as f64;
                (f + 1.0, f - 1.0, f)
            })
            .collect();
        let out = Stochastic::default().calculate(&make_candles(&data)).unwrap();
        for &v in out.get("Stoch_K").unwrap().iter().filter(|v| !v.is_nan()) {
            assert!((0.0..=100.0).contains(&v), "K out of range: {v}");
        }
        for &v in out.get("Stoch_D").unwrap().iter().filter(|v| !v.is_nan()) {
            assert!((0.0..=100.0).contains(&v), "D out of range: {v}");
        }
    }

    #[test]
    fn stoch_no_smoothing_fast_stochastic() {
        // The close of 6 sits 60% of the way up the 0..10 range.
        let params = StochParams { k_period: 3, smooth_k: 1, d_period: 1 };
        let out = Stochastic::new(params)
            .calculate(&uniform_candles(10, 10.0, 0.0, 6.0))
            .unwrap();
        out.get("Stoch_K")
            .unwrap()
            .iter()
            .copied()
            .filter(|v| !v.is_nan())
            .for_each(|v| assert!((v - 60.0).abs() < 1e-9, "expected 60.0, got {v}"));
    }

    #[test]
    fn factory_creates_stochastic() {
        let empty: HashMap<String, String> = HashMap::new();
        let indicator = factory(&empty).unwrap();
        assert_eq!(indicator.name(), "Stochastic");
    }
}