use std::collections::HashMap;
use crate::error::IndicatorError;
use crate::indicator::{Indicator, IndicatorOutput};
use crate::registry::param_usize;
use crate::types::Candle;
/// Tuning parameters for the [`Stochastic`] oscillator.
///
/// All three values are lengths measured in candles.
#[derive(Debug, Clone)]
pub struct StochParams {
    /// Look-back window for the highest high / lowest low behind raw %K.
    pub k_period: usize,
    /// Length of the simple moving average applied to raw %K; `<= 1` means no smoothing.
    pub smooth_k: usize,
    /// Length of the simple moving average applied to smoothed %K to obtain %D.
    pub d_period: usize,
}

impl Default for StochParams {
    /// Returns `k_period = 14`, `smooth_k = 3`, `d_period = 3`.
    fn default() -> Self {
        let (k_period, smooth_k, d_period) = (14, 3, 3);
        Self {
            k_period,
            smooth_k,
            d_period,
        }
    }
}
/// Stochastic oscillator indicator emitting the `Stoch_K` and `Stoch_D` series.
#[derive(Debug, Clone)]
pub struct Stochastic {
    /// Periods controlling the %K window and the two smoothing stages.
    pub params: StochParams,
}

impl Stochastic {
    /// Creates an indicator that uses the given `params`.
    pub fn new(params: StochParams) -> Self {
        Stochastic { params }
    }
}

impl Default for Stochastic {
    /// Creates an indicator configured with [`StochParams::default`].
    fn default() -> Self {
        let params = StochParams::default();
        Stochastic::new(params)
    }
}
impl Indicator for Stochastic {
    fn name(&self) -> &'static str {
        "Stochastic"
    }

    /// Number of candles needed before the first `Stoch_D` value is defined.
    ///
    /// The %K look-back window, the %K smoothing SMA and the %D SMA each add
    /// `period - 1` candles of warm-up on top of the first candle. Every period
    /// is clamped to at least 1 beforehand, so the subtraction cannot underflow.
    fn required_len(&self) -> usize {
        let (kp, sk, dp) = clamp_periods(
            self.params.k_period,
            self.params.smooth_k,
            self.params.d_period,
        );
        kp + sk + dp - 2
    }

    fn required_columns(&self) -> &[&'static str] {
        &["high", "low", "close"]
    }

    /// Computes `Stoch_K` and `Stoch_D`, aligned index-for-index with `candles`.
    ///
    /// * raw %K = `100 * (close - lowest_low) / (highest_high - lowest_low)` over
    ///   the trailing `k_period` candles, or `NaN` when that range is zero;
    /// * `Stoch_K` = SMA(raw %K, `smooth_k`), or raw %K itself when `smooth_k <= 1`;
    /// * `Stoch_D` = SMA(`Stoch_K`, `d_period`).
    ///
    /// Slots without enough history are `NaN`. Zero-valued periods are treated as 1.
    ///
    /// # Errors
    ///
    /// Propagates the error from `check_len` when `candles` is too short
    /// (`IndicatorError::InsufficientData` per this module's tests).
    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
        self.check_len(candles)?;
        let (kp, sk, dp) = clamp_periods(
            self.params.k_period,
            self.params.smooth_k,
            self.params.d_period,
        );

        let mut raw_k = vec![f64::NAN; candles.len()];
        // `windows(kp)` (kp >= 1 after clamping) yields the window that ends at
        // candle `offset + kp - 1`; it yields nothing if there are fewer than kp candles.
        for (offset, window) in candles.windows(kp).enumerate() {
            let hh = window
                .iter()
                .map(|c| c.high)
                .fold(f64::NEG_INFINITY, f64::max);
            let ll = window.iter().map(|c| c.low).fold(f64::INFINITY, f64::min);
            let range = hh - ll;
            let close = window[kp - 1].close;
            raw_k[offset + kp - 1] = if range == 0.0 {
                // Flat window: the oscillator is undefined, leave a gap.
                f64::NAN
            } else {
                100.0 * (close - ll) / range
            };
        }

        // `raw_k` is not used afterwards, so move it rather than cloning it.
        let smooth_k = if sk <= 1 { raw_k } else { sma_of(&raw_k, sk) };
        let d = sma_of(&smooth_k, dp);
        Ok(IndicatorOutput::from_pairs([
            ("Stoch_K".to_string(), smooth_k),
            ("Stoch_D".to_string(), d),
        ]))
    }
}

/// Returns `(k_period, smooth_k, d_period)` with each period clamped to at least 1.
///
/// A zero period has no meaningful interpretation. Treating it as 1 (a
/// single-candle window / no smoothing) keeps `required_len` and `calculate`
/// free of `usize` underflow.
fn clamp_periods(k_period: usize, smooth_k: usize, d_period: usize) -> (usize, usize, usize) {
    (k_period.max(1), smooth_k.max(1), d_period.max(1))
}
/// Simple moving average of `src` that restarts after every `NaN`.
///
/// `out[i]` is the mean of `src[i + 1 - period..=i]` when those `period`
/// values are all non-`NaN`; every other slot is `NaN`. The output has the
/// same length as `src`, and a `period` of zero yields only `NaN`s.
fn sma_of(src: &[f64], period: usize) -> Vec<f64> {
    // Length of the current run of consecutive non-NaN inputs.
    let mut run = 0usize;
    src.iter()
        .enumerate()
        .map(|(idx, &value)| {
            if value.is_nan() {
                run = 0;
                return f64::NAN;
            }
            run += 1;
            if run < period {
                return f64::NAN;
            }
            let window = &src[idx + 1 - period..=idx];
            window.iter().sum::<f64>() / period as f64
        })
        .collect()
}
pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
Ok(Box::new(Stochastic::new(StochParams {
k_period: param_usize(params, "k_period", 14)?,
smooth_k: param_usize(params, "smooth_k", 3)?,
d_period: param_usize(params, "d_period", 3)?,
})))
}
#[cfg(test)]
mod tests {
    use super::*;

    fn make_candles(data: &[(f64, f64, f64)]) -> Vec<Candle> {
        data.iter()
            .enumerate()
            .map(|(i, &(h, l, c))| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: h,
                low: l,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    fn uniform_candles(n: usize, high: f64, low: f64, close: f64) -> Vec<Candle> {
        make_candles(&vec![(high, low, close); n])
    }

    /// Collects the defined (non-`NaN`) values of a series.
    ///
    /// Tests assert on how many values there are as well as on the values
    /// themselves; a bare `filter(!is_nan)` loop would pass vacuously on an
    /// all-`NaN` series.
    fn defined<'a>(series: impl IntoIterator<Item = &'a f64>) -> Vec<f64> {
        series.into_iter().copied().filter(|v| !v.is_nan()).collect()
    }

    #[test]
    fn stoch_insufficient_data() {
        let err = Stochastic::default()
            .calculate(&uniform_candles(5, 12.0, 8.0, 10.0))
            .unwrap_err();
        assert!(matches!(err, IndicatorError::InsufficientData { .. }));
    }

    #[test]
    fn stoch_output_columns_exist() {
        let out = Stochastic::default()
            .calculate(&uniform_candles(30, 12.0, 8.0, 10.0))
            .unwrap();
        assert!(out.get("Stoch_K").is_some());
        assert!(out.get("Stoch_D").is_some());
    }

    #[test]
    fn stoch_known_value_midpoint() {
        let out = Stochastic::new(StochParams {
            k_period: 5,
            smooth_k: 3,
            d_period: 3,
        })
        .calculate(&uniform_candles(20, 12.0, 8.0, 10.0))
        .unwrap();
        let k = out.get("Stoch_K").unwrap();
        let d = out.get("Stoch_D").unwrap();
        let last_k = k.iter().rev().find(|v| !v.is_nan()).copied().unwrap();
        let last_d = d.iter().rev().find(|v| !v.is_nan()).copied().unwrap();
        assert!(
            (last_k - 50.0).abs() < 1e-9,
            "K expected 50.0, got {last_k}"
        );
        assert!(
            (last_d - 50.0).abs() < 1e-9,
            "D expected 50.0, got {last_d}"
        );
    }

    #[test]
    fn stoch_close_at_high_is_100() {
        let out = Stochastic::new(StochParams {
            k_period: 5,
            smooth_k: 1,
            d_period: 1,
        })
        .calculate(&uniform_candles(10, 12.0, 8.0, 12.0))
        .unwrap();
        let k = defined(out.get("Stoch_K").unwrap());
        // k_period = 5 over 10 candles: indices 4..=9 are defined.
        assert_eq!(k.len(), 6, "unexpected number of defined K values");
        for v in k {
            assert!((v - 100.0).abs() < 1e-9, "expected 100.0, got {v}");
        }
    }

    #[test]
    fn stoch_close_at_low_is_0() {
        let out = Stochastic::new(StochParams {
            k_period: 5,
            smooth_k: 1,
            d_period: 1,
        })
        .calculate(&uniform_candles(10, 12.0, 8.0, 8.0))
        .unwrap();
        let k = defined(out.get("Stoch_K").unwrap());
        assert_eq!(k.len(), 6, "unexpected number of defined K values");
        for v in k {
            assert!(v.abs() < 1e-9, "expected 0.0, got {v}");
        }
    }

    #[test]
    fn stoch_range_0_to_100() {
        let mut data = vec![];
        for i in 0..15_i32 {
            let f = f64::from(i);
            data.push((f + 1.0, f - 1.0, f));
        }
        for i in (0..10_i32).rev() {
            let f = f64::from(i);
            data.push((f + 1.0, f - 1.0, f));
        }
        let out = Stochastic::default()
            .calculate(&make_candles(&data))
            .unwrap();
        let k = defined(out.get("Stoch_K").unwrap());
        let d = defined(out.get("Stoch_D").unwrap());
        // 25 candles, defaults 14/3/3: K defined at 15..=24, D at 17..=24.
        assert_eq!(k.len(), 10, "unexpected number of defined K values");
        assert_eq!(d.len(), 8, "unexpected number of defined D values");
        for v in k {
            assert!((0.0..=100.0).contains(&v), "K out of range: {v}");
        }
        for v in d {
            assert!((0.0..=100.0).contains(&v), "D out of range: {v}");
        }
    }

    #[test]
    fn stoch_no_smoothing_fast_stochastic() {
        let out = Stochastic::new(StochParams {
            k_period: 3,
            smooth_k: 1,
            d_period: 1,
        })
        .calculate(&uniform_candles(10, 10.0, 0.0, 6.0))
        .unwrap();
        let k = defined(out.get("Stoch_K").unwrap());
        // k_period = 3 over 10 candles: indices 2..=9 are defined.
        assert_eq!(k.len(), 8, "unexpected number of defined K values");
        for v in k {
            assert!((v - 60.0).abs() < 1e-9, "expected 60.0, got {v}");
        }
    }

    #[test]
    fn stoch_required_len_matches_first_d_value() {
        let ind = Stochastic::default();
        let need = ind.required_len();
        let out = ind
            .calculate(&uniform_candles(need + 2, 12.0, 8.0, 10.0))
            .unwrap();
        let d = out.get("Stoch_D").unwrap();
        let first = d.iter().position(|v| !v.is_nan());
        assert_eq!(first, Some(need - 1));
    }

    #[test]
    fn factory_creates_stochastic() {
        let ind = factory(&HashMap::new()).unwrap();
        assert_eq!(ind.name(), "Stochastic");
    }
}