use std::collections::VecDeque;
use crate::error::{Error, Result};
use crate::traits::Indicator;
use super::Rsi;
/// Stochastic RSI: applies the stochastic-oscillator formula
/// `100 * (rsi - min) / (max - min)` to a rolling window of RSI readings.
///
/// Built with [`StochRsi::new`]; fed one price at a time through the
/// `Indicator` trait's `update`.
#[derive(Debug, Clone)]
pub struct StochRsi {
/// Look-back period handed to the inner `Rsi`.
rsi_period: usize,
/// Number of most recent RSI readings the stochastic min/max is taken over.
stoch_period: usize,
/// Inner RSI, updated with every finite input price.
rsi: Rsi,
/// Rolling window holding at most `stoch_period` RSI readings, oldest at the front.
window: VecDeque<f64>,
/// Most recently emitted StochRSI value; `None` until the window first fills
/// and again after `reset`.
last: Option<f64>,
}
impl StochRsi {
    /// Builds a Stochastic RSI over an RSI of `rsi_period` bars, normalised across the
    /// last `stoch_period` RSI readings.
    ///
    /// # Errors
    ///
    /// Returns `Error::PeriodZero` when either period is zero. Any error produced while
    /// building the inner `Rsi` is propagated unchanged.
    pub fn new(rsi_period: usize, stoch_period: usize) -> Result<Self> {
        // For unsigned periods the smaller one is zero exactly when at least one of them is.
        if rsi_period.min(stoch_period) == 0 {
            return Err(Error::PeriodZero);
        }
        let rsi = Rsi::new(rsi_period)?;
        // The window never holds more than `stoch_period` readings, so size it up front.
        let window = VecDeque::with_capacity(stoch_period);
        Ok(Self {
            rsi_period,
            stoch_period,
            rsi,
            window,
            last: None,
        })
    }

    /// The `(rsi_period, stoch_period)` pair this indicator was constructed with.
    pub const fn periods(&self) -> (usize, usize) {
        (self.rsi_period, self.stoch_period)
    }

    /// The most recent StochRSI output, or `None` if nothing has been emitted yet
    /// (including right after a reset).
    pub const fn value(&self) -> Option<f64> {
        self.last
    }
}
impl Indicator for StochRsi {
    type Input = f64;
    type Output = f64;

    /// Feeds one price and returns the Stochastic RSI for this bar, if available.
    ///
    /// * Non-finite inputs (`NaN`, `±∞`) are skipped. No state changes, and the most recent
    ///   output (or `None` if nothing has been emitted yet) is returned.
    /// * Returns `None` while the inner RSI yields `None`, or while the window of the last
    ///   `stoch_period` RSI readings is still filling.
    /// * Otherwise returns `100 * (rsi - min) / (max - min)` over that window. When the
    ///   window is flat (its spread is at most a tiny tolerance), it returns `50.0`.
    fn update(&mut self, input: f64) -> Option<f64> {
        // Readings that differ only by floating-point rounding noise (on the order of 1e-14)
        // would otherwise be stretched across the whole 0..=100 output range. Windows whose
        // spread is this small are therefore treated as flat.
        // NOTE(review): assumes the inner RSI is on the usual 0..=100 scale, where a spread
        // of 1e-10 carries no signal.
        const FLAT_RANGE_EPSILON: f64 = 1e-10;

        if !input.is_finite() {
            return self.last;
        }
        // Propagate `None` from the inner RSI (e.g. during its warm-up) without touching
        // the window.
        let rsi_value = self.rsi.update(input)?;

        // Keep only the most recent `stoch_period` readings.
        if self.window.len() == self.stoch_period {
            self.window.pop_front();
        }
        self.window.push_back(rsi_value);
        if self.window.len() < self.stoch_period {
            return None;
        }

        // One pass for both extremes; `f64::min` / `f64::max` ignore NaN operands.
        let (min, max) = self
            .window
            .iter()
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| (lo.min(v), hi.max(v)));
        let range = max - min;
        let stoch = if range <= FLAT_RANGE_EPSILON {
            50.0
        } else {
            100.0 * (rsi_value - min) / range
        };
        self.last = Some(stoch);
        Some(stoch)
    }

    /// Clears all accumulated state. It resets the inner RSI, empties the RSI window
    /// and forgets the last emitted value.
    fn reset(&mut self) {
        self.rsi.reset();
        self.window.clear();
        self.last = None;
    }

    /// Number of inputs consumed before the first output: `rsi_period + stoch_period`.
    ///
    /// NOTE(review): this assumes the inner RSI first emits on input `rsi_period + 1`.
    /// After that, `stoch_period - 1` more readings fill the window.
    fn warmup_period(&self) -> usize {
        self.rsi_period + self.stoch_period
    }

    /// `true` once a StochRSI value has been emitted since construction or the last reset.
    fn is_ready(&self) -> bool {
        self.last.is_some()
    }

    /// Display name of the indicator.
    fn name(&self) -> &'static str {
        "StochRSI"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// `100 + amp * sin(step * i)` for `i` in `1..=len`.
    fn sine_prices(len: i32, step: f64, amp: f64) -> Vec<f64> {
        (1..=len)
            .map(|i| 100.0 + (f64::from(i) * step).sin() * amp)
            .collect()
    }

    /// The strictly rising series `1.0, 2.0, ..., len`.
    fn rising_prices(len: i32) -> Vec<f64> {
        (1..=len).map(f64::from).collect()
    }

    #[test]
    fn new_rejects_zero_period() {
        for (rsi_period, stoch_period) in [(0, 14), (14, 0)] {
            assert!(matches!(
                StochRsi::new(rsi_period, stoch_period),
                Err(Error::PeriodZero)
            ));
        }
    }

    #[test]
    fn accessors_and_metadata() {
        let mut sr = StochRsi::new(14, 14).unwrap();
        assert_eq!(sr.periods(), (14, 14));
        assert_eq!(sr.name(), "StochRSI");
        assert_eq!(sr.value(), None);
        let warmup = sr.warmup_period();
        for step in 1..=warmup {
            let offset = u32::try_from(step).unwrap();
            sr.update(100.0 + f64::from(offset));
        }
        assert!(sr.value().is_some());
    }

    #[test]
    fn first_emission_at_warmup_period() {
        let mut sr = StochRsi::new(5, 4).unwrap();
        assert_eq!(sr.warmup_period(), 9);
        let out = sr.batch(&sine_prices(40, 0.6, 8.0));
        assert!(out.iter().take(8).all(Option::is_none));
        assert!(out[8].is_some());
    }

    #[test]
    fn flat_rsi_window_yields_50() {
        let out = StochRsi::new(5, 4).unwrap().batch(&[100.0; 40]);
        for &value in out.iter().skip(9).flatten() {
            assert_relative_eq!(value, 50.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn pure_uptrend_yields_50() {
        let out = StochRsi::new(5, 4).unwrap().batch(&rising_prices(40));
        for &value in out.iter().skip(9).flatten() {
            assert_relative_eq!(value, 50.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn output_stays_within_0_100() {
        let prices: Vec<f64> = (1..=200)
            .map(|i| {
                let t = f64::from(i);
                100.0 + (t * 0.3).sin() * 15.0 + (t * 0.07).cos() * 6.0
            })
            .collect();
        let mut sr = StochRsi::new(14, 14).unwrap();
        for v in sr.batch(&prices).into_iter().flatten() {
            assert!((0.0..=100.0).contains(&v), "StochRSI out of range: {v}");
        }
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut sr = StochRsi::new(5, 4).unwrap();
        let out = sr.batch(&sine_prices(40, 0.6, 8.0));
        let last = *out.last().unwrap();
        assert!(last.is_some());
        for bad in [f64::NAN, f64::INFINITY] {
            assert_eq!(sr.update(bad), last);
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut sr = StochRsi::new(5, 4).unwrap();
        sr.batch(&rising_prices(40));
        assert!(sr.is_ready());
        sr.reset();
        assert!(!sr.is_ready());
        assert_eq!(sr.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices = sine_prices(120, 0.25, 12.0);
        let batched = StochRsi::new(14, 14).unwrap().batch(&prices);
        let mut live = StochRsi::new(14, 14).unwrap();
        let streamed: Vec<_> = prices.iter().map(|&price| live.update(price)).collect();
        assert_eq!(batched, streamed);
    }
}