Skip to main content

finance_query/indicators/
stochastic.rs

//! Stochastic Oscillator indicator.
2
3use super::{IndicatorError, Result, sma::sma};
4use serde::{Deserialize, Serialize};
5
6/// Result of Stochastic Oscillator calculation
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub struct StochasticResult {
9    /// %K line
10    pub k: Vec<Option<f64>>,
11    /// %D line (Signal line)
12    pub d: Vec<Option<f64>>,
13}
14
15/// Calculate Stochastic Oscillator.
16///
17/// Returns (%K, %D) where:
18/// %K = (Close - Lowest Low) / (Highest High - Lowest Low) * 100
19/// %D = SMA of %K
20///
21/// # Arguments
22///
23/// * `highs` - High prices
24/// * `lows` - Low prices
25/// * `closes` - Close prices
26/// * `k_period` - Period for %K
27/// * `d_period` - Period for %D (SMA of %K)
28///
29/// # Example
30///
31/// ```
32/// use finance_query::indicators::stochastic;
33///
34/// let highs = vec![10.0, 11.0, 12.0, 13.0, 14.0];
35/// let lows = vec![8.0, 9.0, 10.0, 11.0, 12.0];
36/// let closes = vec![9.0, 10.0, 11.0, 12.0, 13.0];
37/// let result = stochastic(&highs, &lows, &closes, 3, 2).unwrap();
38/// ```
39pub fn stochastic(
40    highs: &[f64],
41    lows: &[f64],
42    closes: &[f64],
43    k_period: usize,
44    d_period: usize,
45) -> Result<StochasticResult> {
46    if k_period == 0 || d_period == 0 {
47        return Err(IndicatorError::InvalidPeriod(
48            "Periods must be greater than 0".to_string(),
49        ));
50    }
51    let len = highs.len();
52    if lows.len() != len || closes.len() != len {
53        return Err(IndicatorError::InvalidPeriod(
54            "Data lengths must match".to_string(),
55        ));
56    }
57    if len < k_period {
58        return Err(IndicatorError::InsufficientData {
59            need: k_period,
60            got: len,
61        });
62    }
63
64    let mut k_values = vec![None; len];
65    let mut k_series_for_sma = vec![0.0; len];
66
67    // Calculate %K values
68    for i in (k_period - 1)..len {
69        let start_idx = i + 1 - k_period;
70        let end_idx = i;
71
72        let period_highs = &highs[start_idx..=end_idx];
73        let period_lows = &lows[start_idx..=end_idx];
74
75        let highest = period_highs
76            .iter()
77            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
78        let lowest = period_lows.iter().fold(f64::INFINITY, |a, &b| a.min(b));
79
80        let range = highest - lowest;
81        let k = if range == 0.0 {
82            50.0 // Neutral when no range
83        } else {
84            ((closes[end_idx] - lowest) / range) * 100.0
85        };
86
87        k_values[i] = Some(k);
88        k_series_for_sma[i] = k;
89    }
90
91    // Calculate %D (SMA of %K)
92    let valid_k_start = k_period - 1;
93    if len <= valid_k_start {
94        return Ok(StochasticResult {
95            k: k_values,
96            d: vec![None; len],
97        });
98    }
99
100    let valid_k_slice = &k_series_for_sma[valid_k_start..];
101
102    let d_values_valid = sma(valid_k_slice, d_period);
103
104    let mut d_values = vec![None; len];
105    for (j, val) in d_values_valid.into_iter().enumerate() {
106        let original_idx = j + valid_k_start;
107        if original_idx < len {
108            d_values[original_idx] = val;
109        }
110    }
111
112    Ok(StochasticResult {
113        k: k_values,
114        d: d_values,
115    })
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Rising series in which every complete 3-bar window has its close exactly
    /// three quarters of the way up the window's high–low range.
    fn rising() -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        (
            vec![10.0, 11.0, 12.0, 13.0, 14.0],
            vec![8.0, 9.0, 10.0, 11.0, 12.0],
            vec![9.0, 10.0, 11.0, 12.0, 13.0],
        )
    }

    #[test]
    fn test_stochastic() {
        let (highs, lows, closes) = rising();
        let result = stochastic(&highs, &lows, &closes, 3, 2).unwrap();

        assert_eq!(result.k.len(), 5);
        assert_eq!(result.d.len(), 5);

        // %K is defined from index k_period - 1 = 2; each window gives
        // (close - lowest) / (highest - lowest) = 3 / 4.
        assert_eq!(
            result.k,
            vec![None, None, Some(75.0), Some(75.0), Some(75.0)]
        );

        // %D is defined from index 2 + (d_period - 1) = 3.
        assert!(result.d[..3].iter().all(Option::is_none));
        for &d in &result.d[3..] {
            let d = d.expect("%D should be defined after the warm-up");
            assert!((d - 75.0).abs() < 1e-9);
        }
    }

    #[test]
    fn test_flat_range_is_neutral() {
        let flat = vec![5.0; 4];
        let result = stochastic(&flat, &flat, &flat, 2, 1).unwrap();
        assert_eq!(result.k, vec![None, Some(50.0), Some(50.0), Some(50.0)]);
    }

    #[test]
    fn test_exact_k_period_length() {
        let prices = vec![1.0, 2.0, 3.0];
        let result = stochastic(&prices, &prices, &prices, 3, 2).unwrap();
        assert_eq!(result.k, vec![None, None, Some(100.0)]);
        // Only one %K value exists, so a 2-period %D is never defined.
        assert_eq!(result.d, vec![None, None, None]);
    }

    #[test]
    fn test_zero_period_is_rejected() {
        let prices = vec![1.0, 2.0, 3.0];
        assert!(matches!(
            stochastic(&prices, &prices, &prices, 0, 1),
            Err(IndicatorError::InvalidPeriod(_))
        ));
        assert!(matches!(
            stochastic(&prices, &prices, &prices, 1, 0),
            Err(IndicatorError::InvalidPeriod(_))
        ));
    }

    #[test]
    fn test_mismatched_lengths_are_rejected() {
        let prices = vec![1.0, 2.0, 3.0];
        assert!(stochastic(&prices, &prices[..2], &prices, 1, 1).is_err());
        assert!(stochastic(&prices, &prices, &prices[..2], 1, 1).is_err());
    }

    #[test]
    fn test_insufficient_data() {
        let prices = vec![1.0, 2.0, 3.0];
        assert!(matches!(
            stochastic(&prices, &prices, &prices, 4, 1),
            Err(IndicatorError::InsufficientData { need: 4, got: 3 })
        ));
    }
}