use std::collections::VecDeque;
use super::{IndicatorError, Result, rsi::rsi_raw, sma::sma_raw, stochastic::StochasticResult};
pub fn stochastic_rsi(
data: &[f64],
rsi_period: usize,
stoch_period: usize,
k_period: usize,
d_period: usize,
) -> Result<StochasticResult> {
if rsi_period == 0 || stoch_period == 0 || k_period == 0 || d_period == 0 {
return Err(IndicatorError::InvalidPeriod(
"Periods must be greater than 0".to_string(),
));
}
let min_len = rsi_period + stoch_period;
if data.len() < min_len {
return Err(IndicatorError::InsufficientData {
need: min_len,
got: data.len(),
});
}
let rsi_dense = rsi_raw(data, rsi_period)?;
let len = data.len();
stochastic_rsi_from_rsi_dense(
&rsi_dense,
len,
rsi_period,
stoch_period,
k_period,
d_period,
)
}
pub(crate) fn stochastic_rsi_from_rsi_dense(
rsi_dense: &[f64],
len: usize,
rsi_period: usize,
stoch_period: usize,
k_period: usize,
d_period: usize,
) -> Result<StochasticResult> {
if stoch_period == 0 || k_period == 0 || d_period == 0 {
return Err(IndicatorError::InvalidPeriod(
"Periods must be greater than 0".to_string(),
));
}
let rsi_len = rsi_dense.len();
if rsi_len < stoch_period {
return Err(IndicatorError::InsufficientData {
need: rsi_period + stoch_period,
got: len,
});
}
let raw_start = rsi_period + stoch_period - 1;
let raw_count = rsi_len.saturating_sub(stoch_period - 1);
let mut raw_stoch_dense = Vec::with_capacity(raw_count);
{
let mut max_deque: VecDeque<usize> = VecDeque::new();
let mut min_deque: VecDeque<usize> = VecDeque::new();
for k in 0..rsi_len {
while max_deque.front().is_some_and(|&j| j + stoch_period <= k) {
max_deque.pop_front();
}
while min_deque.front().is_some_and(|&j| j + stoch_period <= k) {
min_deque.pop_front();
}
while max_deque
.back()
.is_some_and(|&j| rsi_dense[j] <= rsi_dense[k])
{
max_deque.pop_back();
}
while min_deque
.back()
.is_some_and(|&j| rsi_dense[j] >= rsi_dense[k])
{
min_deque.pop_back();
}
max_deque.push_back(k);
min_deque.push_back(k);
if k + 1 >= stoch_period {
let max_rsi = rsi_dense[*max_deque.front().unwrap()];
let min_rsi = rsi_dense[*min_deque.front().unwrap()];
let range = max_rsi - min_rsi;
raw_stoch_dense.push(if range.abs() < f64::EPSILON {
50.0
} else {
(rsi_dense[k] - min_rsi) / range * 100.0
});
}
}
}
let k_dense: Vec<f64>;
let (k_line, k_valid_start) = if k_period == 1 {
k_dense = raw_stoch_dense.clone();
let mut k_line = vec![None; len];
for (j, &v) in raw_stoch_dense.iter().enumerate() {
k_line[j + raw_start] = Some(v);
}
(k_line, raw_start)
} else {
k_dense = sma_raw(&raw_stoch_dense, k_period);
let k_start = raw_start + k_period - 1;
let mut k_line = vec![None; len];
for (j, &val) in k_dense.iter().enumerate() {
let idx = j + k_start;
if idx < len {
k_line[idx] = Some(val);
}
}
(k_line, k_start)
};
let d_line = if d_period == 1 {
k_line.clone()
} else {
let d_raw = sma_raw(&k_dense, d_period);
let d_start = k_valid_start + d_period - 1;
let mut d_line = vec![None; len];
for (j, &val) in d_raw.iter().enumerate() {
let idx = j + d_start;
if idx < len {
d_line[idx] = Some(val);
}
}
d_line
};
Ok(StochasticResult {
k: k_line,
d: d_line,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_stochastic_rsi_no_smoothing() {
let prices = vec![10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0];
let result = stochastic_rsi(&prices, 3, 3, 1, 1).unwrap();
assert_eq!(result.k.len(), 9);
assert!(result.k[0].is_none());
assert!(result.k[4].is_none());
assert!(result.k[5].is_some());
}
#[test]
fn test_stochastic_rsi_with_smoothing() {
let prices: Vec<f64> = (1..=30).map(|i| i as f64).collect();
let result = stochastic_rsi(&prices, 3, 3, 3, 3).unwrap();
assert_eq!(result.k.len(), 30);
assert!(result.k[6].is_none() || result.k[6].is_some()); let k_last = result.k.iter().filter_map(|v| *v).next_back();
let d_last = result.d.iter().filter_map(|v| *v).next_back();
assert!(k_last.is_some());
assert!(d_last.is_some());
}
#[test]
fn test_stochastic_rsi_k_before_d() {
let prices: Vec<f64> = (1..=40).map(|i| i as f64).collect();
let result = stochastic_rsi(&prices, 5, 5, 3, 3).unwrap();
let k_first = result.k.iter().position(|v| v.is_some());
let d_first = result.d.iter().position(|v| v.is_some());
assert!(k_first.is_some());
assert!(d_first.is_some());
assert!(d_first.unwrap() >= k_first.unwrap());
}
}