use crate::indicators::utils::{validate_data_length, validate_period};
use crate::indicators::{Candle, Indicator, IndicatorError};
use std::collections::VecDeque;
/// Stochastic Oscillator indicator.
///
/// %K places each candle's close within the high/low range of the last `k_period`
/// candles, and %D is the arithmetic mean of the most recent `d_period` %K values.
#[derive(Debug)]
pub struct StochasticOscillator {
    /// Number of candles in the %K high/low lookback window (validated to be non-zero).
    k_period: usize,
    /// Number of consecutive %K values averaged to produce %D (validated to be non-zero).
    d_period: usize,
    /// Sliding window holding at most `d_period` of the latest %K values, used for %D.
    k_buffer: VecDeque<f64>,
}
impl StochasticOscillator {
    /// Builds an oscillator with a `k_period`-candle %K window and a `d_period`-value
    /// %D average.
    ///
    /// # Errors
    /// Returns the error produced by `validate_period` when either period is below the
    /// minimum of 1. `k_period` is checked first, then `d_period`.
    pub fn new(k_period: usize, d_period: usize) -> Result<Self, IndicatorError> {
        for period in [k_period, d_period] {
            validate_period(period, 1)?;
        }
        let k_buffer = VecDeque::with_capacity(d_period);
        Ok(Self {
            k_period,
            d_period,
            k_buffer,
        })
    }

    /// Raw %K for the candle at `idx`, over the `period` candles ending at `idx`.
    ///
    /// %K = (close - lowest low) / (highest high - lowest low) * 100.
    ///
    /// Returns the neutral value 50.0 in two cases:
    /// - fewer than `period` candles are available (`idx < period - 1`);
    /// - the window's highest high equals its lowest low, which would otherwise divide
    ///   by zero.
    ///
    /// # Panics
    /// Panics if `period == 0` or `idx >= candles.len()`. Callers pass the validated
    /// `k_period` and in-range indices.
    fn calculate_k(candles: &[Candle], idx: usize, period: usize) -> f64 {
        let lookback = period - 1;
        if idx < lookback {
            return 50.0;
        }

        // Inclusive window of `period` candles ending at `idx`.
        let window = &candles[idx - lookback..=idx];
        let close = window[window.len() - 1].close;

        // Seed with the first candle, then fold the rest in order.
        let (low, high) = window[1..].iter().fold(
            (window[0].low, window[0].high),
            |(lo, hi), candle| (lo.min(candle.low), hi.max(candle.high)),
        );

        if high == low {
            50.0
        } else {
            (close - low) / (high - low) * 100.0
        }
    }
}
/// A single Stochastic Oscillator reading produced by `StochasticOscillator::calculate`.
#[derive(Clone, Copy, Debug)]
pub struct StochasticResult {
    /// %K: position of the close within the lookback high/low range, as a percentage.
    /// For candles with `low <= close <= high` it lies in 0..=100. It is 50.0 when the
    /// range is flat.
    pub k: f64,
    /// %D: arithmetic mean of the latest `d_period` %K values.
    pub d: f64,
}
impl Indicator<Candle, StochasticResult> for StochasticOscillator {
    /// Computes %K and %D for every candle where both are defined.
    ///
    /// %K first becomes defined at index `k_period - 1`. %D needs `d_period` %K values,
    /// so the first result corresponds to candle index `k_period + d_period - 2`. The
    /// output therefore has `data.len() - (k_period + d_period - 1) + 1` entries.
    ///
    /// Any internal state from a previous call is cleared first. Afterwards the %K
    /// buffer holds the last `d_period` %K values.
    ///
    /// # Errors
    /// Propagates the error from `validate_data_length` when `data` is shorter than
    /// `k_period + d_period - 1` candles.
    fn calculate(&mut self, data: &[Candle]) -> Result<Vec<StochasticResult>, IndicatorError> {
        let min_len = self.k_period + self.d_period - 1;
        validate_data_length(data, min_len)?;

        // `len - min_len + 1` outputs. `saturating_sub` keeps this from underflowing
        // `usize` when `data.len() == min_len`, the shortest accepted input.
        let mut result = Vec::with_capacity(data.len().saturating_sub(min_len) + 1);
        self.reset();

        // Indices before `k_period - 1` have no full %K window, so skip them entirely.
        let k_start_idx = self.k_period - 1;
        for i in k_start_idx..data.len() {
            let k = Self::calculate_k(data, i, self.k_period);

            // Keep only the last `d_period` %K values for the %D moving average.
            self.k_buffer.push_back(k);
            if self.k_buffer.len() > self.d_period {
                self.k_buffer.pop_front();
            }

            if self.k_buffer.len() == self.d_period {
                let d = self.k_buffer.iter().sum::<f64>() / self.d_period as f64;
                result.push(StochasticResult { k, d });
            }
        }
        Ok(result)
    }

    /// Incremental updates are not supported and always return
    /// `IndicatorError::CalculationError`.
    ///
    /// This struct keeps only recent %K values, not the previous `k_period` candles
    /// that %K needs. Use [`Indicator::calculate`] on the full candle slice instead.
    fn next(&mut self, _value: Candle) -> Result<Option<StochasticResult>, IndicatorError> {
        Err(IndicatorError::CalculationError(
            "Real-time calculation of Stochastic Oscillator requires storing previous candles"
                .to_string(),
        ))
    }

    /// Clears the buffered %K values used for the %D average.
    fn reset(&mut self) {
        self.k_buffer.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::Candle;

    /// Tolerance for comparing floating-point indicator output.
    const EPS: f64 = 1e-9;

    /// Builds candles from `(open, high, low, close)` tuples.
    ///
    /// Timestamps are sequential starting at 1 and volume is constant. The oscillator
    /// only reads high, low and close.
    fn make_candles(bars: &[(f64, f64, f64, f64)]) -> Vec<Candle> {
        bars.iter()
            .enumerate()
            .map(|(i, &(open, high, low, close))| Candle {
                timestamp: (i + 1) as _,
                open,
                high,
                low,
                close,
                volume: 1000.0,
            })
            .collect()
    }

    /// Five-bar fixture: four rising bars followed by a pullback.
    fn sample_candles() -> Vec<Candle> {
        make_candles(&[
            (10.0, 12.0, 9.0, 11.0),
            (11.0, 13.0, 10.0, 12.0),
            (12.0, 14.0, 11.0, 13.0),
            (13.0, 15.0, 12.0, 14.0),
            (14.0, 16.0, 11.0, 13.0),
        ])
    }

    /// Asserts two result series are equal within `EPS`.
    fn assert_results_close(got: &[StochasticResult], expected: &[(f64, f64)]) {
        assert_eq!(got.len(), expected.len());
        for (g, &(k, d)) in got.iter().zip(expected) {
            assert!((g.k - k).abs() < EPS, "k: got {}, expected {}", g.k, k);
            assert!((g.d - d).abs() < EPS, "d: got {}, expected {}", g.d, d);
        }
    }

    #[test]
    fn test_stochastic_new() {
        assert!(StochasticOscillator::new(14, 3).is_ok());
        assert!(StochasticOscillator::new(0, 3).is_err());
        assert!(StochasticOscillator::new(14, 0).is_err());
    }

    #[test]
    fn test_stochastic_calculation() {
        let mut stoch = StochasticOscillator::new(3, 2).unwrap();
        let result = stoch.calculate(&sample_candles()).unwrap();
        // 3-bar %K at indices 2, 3, 4 is 80, 80, 40.
        // %D is the 2-bar mean: (80+80)/2 = 80, then (80+40)/2 = 60.
        assert_results_close(&result, &[(80.0, 80.0), (40.0, 60.0)]);
        for r in &result {
            assert!((0.0..=100.0).contains(&r.k));
            assert!((0.0..=100.0).contains(&r.d));
        }
    }

    #[test]
    fn test_stochastic_flat_range_is_midpoint() {
        let mut stoch = StochasticOscillator::new(2, 2).unwrap();
        let candles = make_candles(&[(5.0, 5.0, 5.0, 5.0); 4]);
        let result = stoch.calculate(&candles).unwrap();
        // A zero high/low range yields the neutral 50.0 for both %K and %D.
        assert_results_close(&result, &[(50.0, 50.0), (50.0, 50.0)]);
    }

    #[test]
    fn test_stochastic_insufficient_data() {
        let mut stoch = StochasticOscillator::new(14, 3).unwrap();
        let candles = make_candles(&[(10.0, 12.0, 9.0, 11.0), (11.0, 13.0, 10.0, 12.0)]);
        assert!(stoch.calculate(&candles).is_err());
    }

    #[test]
    fn test_stochastic_next_error() {
        let mut stoch = StochasticOscillator::new(14, 3).unwrap();
        let candle = make_candles(&[(10.0, 12.0, 9.0, 11.0)]).remove(0);
        assert!(stoch.next(candle).is_err());
    }

    #[test]
    fn test_stochastic_reset() {
        let mut stoch = StochasticOscillator::new(3, 2).unwrap();
        let candles = sample_candles();

        let first = stoch.calculate(&candles).unwrap();
        assert!(!stoch.k_buffer.is_empty());

        stoch.reset();
        assert!(stoch.k_buffer.is_empty());

        // Results after a reset must match the first run exactly.
        let second = stoch.calculate(&candles).unwrap();
        let expected: Vec<(f64, f64)> = first.iter().map(|r| (r.k, r.d)).collect();
        assert_results_close(&second, &expected);
    }
}