Skip to main content

indicators/momentum/
stochastic.rs

1//! Stochastic Oscillator (%K and %D).
2//!
3//! Python source: `indicators/momentum/stochastic.py :: class Stochastic`
4//!
5//! # Algorithm
6//!
7//! 1. **Raw %K**:
8//!    `%K[i] = 100 * (close[i] - lowest_low) / (highest_high - lowest_low)`
9//!    where the window is `k_period` bars ending at `i`.
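//!    Yields `NaN` when `highest_high == lowest_low` (a flat window). Because
//!    the SMAs below only average unbroken runs of non-`NaN` values, each such
//!    `NaN` also blanks every smoothed-%K and %D window that contains it.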
11//!
12//! 2. **Smooth %K** (optional): SMA of raw %K over `smooth_k` bars.
13//!    `smooth_k = 1` means no smoothing (fast stochastic).
14//!    `smooth_k = 3` is the standard slow stochastic.
15//!
16//! 3. **%D**: SMA of smooth %K over `d_period` bars.
17//!
18//! Output columns: `"Stoch_K"`, `"Stoch_D"`.
19
20use std::collections::HashMap;
21
22use crate::error::IndicatorError;
23use crate::indicator::{Indicator, IndicatorOutput};
24use crate::registry::param_usize;
25use crate::types::Candle;
26
27// ── Params ────────────────────────────────────────────────────────────────────
28
/// Window lengths for the [`Stochastic`] oscillator.
///
/// The defaults (14 / 3 / 3) describe the standard slow stochastic.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StochParams {
    /// Look-back window for the highest high / lowest low. Default: 14.
    pub k_period: usize,
    /// SMA length applied to raw %K. `1` (or `0`) means no smoothing, i.e. the
    /// fast stochastic. Default: 3.
    pub smooth_k: usize,
    /// SMA length applied to the smoothed %K to produce %D. Default: 3.
    pub d_period: usize,
}

impl Default for StochParams {
    fn default() -> Self {
        Self {
            k_period: 14,
            smooth_k: 3,
            d_period: 3,
        }
    }
}
48
49// ── Indicator struct ──────────────────────────────────────────────────────────
50
51#[derive(Debug, Clone)]
52pub struct Stochastic {
53    pub params: StochParams,
54}
55
56impl Stochastic {
57    pub fn new(params: StochParams) -> Self {
58        Self { params }
59    }
60}
61
62impl Default for Stochastic {
63    fn default() -> Self {
64        Self::new(StochParams::default())
65    }
66}
67
68impl Indicator for Stochastic {
69    fn name(&self) -> &'static str {
70        "Stochastic"
71    }
72
73    fn required_len(&self) -> usize {
74        self.params.k_period + self.params.smooth_k + self.params.d_period - 2
75    }
76
77    fn required_columns(&self) -> &[&'static str] {
78        &["high", "low", "close"]
79    }
80
81    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
82        self.check_len(candles)?;
83
84        let n = candles.len();
85        let kp = self.params.k_period;
86        let sk = self.params.smooth_k;
87        let dp = self.params.d_period;
88
89        // ── Step 1: raw %K ────────────────────────────────────────────────────
90        let mut raw_k = vec![f64::NAN; n];
91        for i in (kp - 1)..n {
92            let window = &candles[(i + 1 - kp)..=i];
93            let hh = window
94                .iter()
95                .map(|c| c.high)
96                .fold(f64::NEG_INFINITY, f64::max);
97            let ll = window.iter().map(|c| c.low).fold(f64::INFINITY, f64::min);
98            let range = hh - ll;
99            raw_k[i] = if range == 0.0 {
100                f64::NAN
101            } else {
102                100.0 * (candles[i].close - ll) / range
103            };
104        }
105
106        // ── Step 2: smooth %K (SMA) ───────────────────────────────────────────
107        let smooth_k = if sk <= 1 {
108            raw_k.clone()
109        } else {
110            sma_of(&raw_k, sk)
111        };
112
113        // ── Step 3: %D (SMA of smooth_k) ─────────────────────────────────────
114        let d = sma_of(&smooth_k, dp);
115
116        Ok(IndicatorOutput::from_pairs([
117            ("Stoch_K".to_string(), smooth_k),
118            ("Stoch_D".to_string(), d),
119        ]))
120    }
121}
122
/// NaN-aware simple moving average of `src` over `period` samples.
///
/// The result has `src.len()` elements, index-aligned with `src`. A slot is
/// filled only when the `period` inputs ending there are all non-`NaN`. A `NaN`
/// input restarts the streak, so every warm-up slot and every window that would
/// straddle a `NaN` stays `NaN`. Because of this guarantee, the window sum never
/// needs its own `NaN` check.
fn sma_of(src: &[f64], period: usize) -> Vec<f64> {
    // Length of the current unbroken streak of non-NaN inputs.
    let mut run = 0usize;
    src.iter()
        .enumerate()
        .map(|(idx, &value)| {
            if value.is_nan() {
                run = 0;
                return f64::NAN;
            }
            run += 1;
            if run < period {
                return f64::NAN;
            }
            let total: f64 = src[idx + 1 - period..=idx].iter().sum();
            total / period as f64
        })
        .collect()
}
147
148// ── Registry factory ──────────────────────────────────────────────────────────
149
150pub fn factory<S: ::std::hash::BuildHasher>(
151    params: &HashMap<String, String, S>,
152) -> Result<Box<dyn Indicator>, IndicatorError> {
153    Ok(Box::new(Stochastic::new(StochParams {
154        k_period: param_usize(params, "k_period", 14)?,
155        smooth_k: param_usize(params, "smooth_k", 3)?,
156        d_period: param_usize(params, "d_period", 3)?,
157    })))
158}
159
160// ── Tests ─────────────────────────────────────────────────────────────────────
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles from `(high, low, close)` triples; `open = close`,
    /// `volume = 1`, and `time` is the bar index.
    fn make_candles(data: &[(f64, f64, f64)]) -> Vec<Candle> {
        data.iter()
            .enumerate()
            .map(|(i, &(h, l, c))| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: h,
                low: l,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    /// `n` identical candles.
    fn uniform_candles(n: usize, high: f64, low: f64, close: f64) -> Vec<Candle> {
        make_candles(&vec![(high, low, close); n])
    }

    /// The non-NaN values of an output column, in order.
    ///
    /// Asserting on this collection (including its length) keeps the tests
    /// from passing vacuously when a column is entirely NaN.
    fn non_nan<'a>(values: impl IntoIterator<Item = &'a f64>) -> Vec<f64> {
        values.into_iter().copied().filter(|v| !v.is_nan()).collect()
    }

    /// Number of leading NaN (warm-up) entries in an output column.
    fn warmup_len<'a>(values: impl IntoIterator<Item = &'a f64>) -> usize {
        values.into_iter().take_while(|v| v.is_nan()).count()
    }

    #[test]
    fn stoch_insufficient_data() {
        let err = Stochastic::default()
            .calculate(&uniform_candles(5, 12.0, 8.0, 10.0))
            .unwrap_err();
        assert!(matches!(err, IndicatorError::InsufficientData { .. }));
    }

    #[test]
    fn stoch_required_len_default() {
        // 14-bar window, then a 3-bar %K SMA and a 3-bar %D SMA: 14 + 2 + 2.
        assert_eq!(Stochastic::default().required_len(), 18);
    }

    #[test]
    fn stoch_exact_required_len_yields_one_d_value() {
        let ind = Stochastic::default();
        let out = ind
            .calculate(&uniform_candles(ind.required_len(), 12.0, 8.0, 10.0))
            .unwrap();
        let d = non_nan(out.get("Stoch_D").unwrap());
        assert_eq!(d.len(), 1, "expected exactly one %D value, got {d:?}");
        assert!((d[0] - 50.0).abs() < 1e-9, "D expected 50.0, got {}", d[0]);
    }

    #[test]
    fn stoch_output_columns_exist() {
        let out = Stochastic::default()
            .calculate(&uniform_candles(30, 12.0, 8.0, 10.0))
            .unwrap();
        assert!(out.get("Stoch_K").is_some());
        assert!(out.get("Stoch_D").is_some());
    }

    #[test]
    fn stoch_known_value_midpoint() {
        // high=12, low=8, close=10 on every bar → raw %K = 100*(10-8)/(12-8) = 50,
        // and any SMA of a constant series is that constant.
        let out = Stochastic::new(StochParams {
            k_period: 5,
            smooth_k: 3,
            d_period: 3,
        })
        .calculate(&uniform_candles(20, 12.0, 8.0, 10.0))
        .unwrap();
        let k = out.get("Stoch_K").unwrap();
        let d = out.get("Stoch_D").unwrap();
        // Warm-up: (5 - 1) + (3 - 1) bars for K, plus (3 - 1) more for D.
        assert_eq!(warmup_len(k), 6);
        assert_eq!(warmup_len(d), 8);
        let ks = non_nan(k);
        let ds = non_nan(d);
        assert_eq!(ks.len(), 14);
        assert_eq!(ds.len(), 12);
        for v in ks {
            assert!((v - 50.0).abs() < 1e-9, "K expected 50.0, got {v}");
        }
        for v in ds {
            assert!((v - 50.0).abs() < 1e-9, "D expected 50.0, got {v}");
        }
    }

    #[test]
    fn stoch_close_at_high_is_100() {
        // close == high → raw %K = 100.
        let out = Stochastic::new(StochParams {
            k_period: 5,
            smooth_k: 1,
            d_period: 1,
        })
        .calculate(&uniform_candles(10, 12.0, 8.0, 12.0))
        .unwrap();
        let k = out.get("Stoch_K").unwrap();
        assert_eq!(warmup_len(k), 4);
        let ks = non_nan(k);
        assert_eq!(ks.len(), 6);
        for v in ks {
            assert!((v - 100.0).abs() < 1e-9, "expected 100.0, got {v}");
        }
    }

    #[test]
    fn stoch_close_at_low_is_0() {
        // close == low → raw %K = 0.
        let out = Stochastic::new(StochParams {
            k_period: 5,
            smooth_k: 1,
            d_period: 1,
        })
        .calculate(&uniform_candles(10, 12.0, 8.0, 8.0))
        .unwrap();
        let k = out.get("Stoch_K").unwrap();
        assert_eq!(warmup_len(k), 4);
        let ks = non_nan(k);
        assert_eq!(ks.len(), 6);
        for v in ks {
            assert!(v.abs() < 1e-9, "expected 0.0, got {v}");
        }
    }

    #[test]
    fn stoch_flat_window_is_nan() {
        // high == low on every bar → zero range → %K, and hence %D, undefined.
        let out = Stochastic::new(StochParams {
            k_period: 3,
            smooth_k: 1,
            d_period: 1,
        })
        .calculate(&uniform_candles(10, 10.0, 10.0, 10.0))
        .unwrap();
        assert!(non_nan(out.get("Stoch_K").unwrap()).is_empty());
        assert!(non_nan(out.get("Stoch_D").unwrap()).is_empty());
    }

    #[test]
    fn stoch_range_0_to_100() {
        // Rising then falling sequence.
        let mut data = vec![];
        for i in 0_u8..15 {
            let f = f64::from(i);
            data.push((f + 1.0, f - 1.0, f));
        }
        for i in (0_u8..10).rev() {
            let f = f64::from(i);
            data.push((f + 1.0, f - 1.0, f));
        }
        let out = Stochastic::default()
            .calculate(&make_candles(&data))
            .unwrap();
        let ks = non_nan(out.get("Stoch_K").unwrap());
        let ds = non_nan(out.get("Stoch_D").unwrap());
        assert!(!ks.is_empty(), "expected some %K values");
        assert!(!ds.is_empty(), "expected some %D values");
        for v in ks {
            assert!((0.0..=100.0).contains(&v), "K out of range: {v}");
        }
        for v in ds {
            assert!((0.0..=100.0).contains(&v), "D out of range: {v}");
        }
    }

    #[test]
    fn stoch_no_smoothing_fast_stochastic() {
        // smooth_k=1 → raw %K passed through; d_period=1 → %D equals %K.
        let out = Stochastic::new(StochParams {
            k_period: 3,
            smooth_k: 1,
            d_period: 1,
        })
        .calculate(&uniform_candles(10, 10.0, 0.0, 6.0))
        .unwrap();
        // close=6, range=10 → 60.0.
        let ks = non_nan(out.get("Stoch_K").unwrap());
        assert_eq!(ks.len(), 8);
        for &v in &ks {
            assert!((v - 60.0).abs() < 1e-9, "expected 60.0, got {v}");
        }
        assert_eq!(non_nan(out.get("Stoch_D").unwrap()), ks);
    }

    #[test]
    fn sma_of_restarts_after_nan() {
        let out = sma_of(&[1.0, 2.0, f64::NAN, 4.0, 6.0, 8.0], 2);
        let expected = [f64::NAN, 1.5, f64::NAN, f64::NAN, 5.0, 7.0];
        assert_eq!(out.len(), expected.len());
        for (i, (&got, &want)) in out.iter().zip(&expected).enumerate() {
            if want.is_nan() {
                assert!(got.is_nan(), "index {i}: expected NaN, got {got}");
            } else {
                assert!(
                    (got - want).abs() < 1e-12,
                    "index {i}: expected {want}, got {got}"
                );
            }
        }
    }

    #[test]
    fn factory_creates_stochastic() {
        let ind = factory(&HashMap::new()).unwrap();
        assert_eq!(ind.name(), "Stochastic");
        assert_eq!(ind.required_len(), Stochastic::default().required_len());
    }

    #[test]
    fn factory_honours_overrides() {
        let params = HashMap::from([("k_period".to_string(), "5".to_string())]);
        let ind = factory(&params).unwrap();
        // 5-bar window with the default 3/3 smoothing: 5 + 2 + 2.
        assert_eq!(ind.required_len(), 9);
    }
}
309}