use crate::error::{Error, Result};
use crate::ohlcv::Candle;
use crate::traits::Indicator;
/// Streaming RSI-style oscillator driven by candle-to-candle changes in volume.
///
/// Each update compares the incoming candle's volume with the previous one and
/// splits the difference into a gain or a loss. The first `period` changes are
/// averaged to seed the running averages; afterwards each new change is blended
/// in with weight `1 / period`. The emitted value is
/// `100 * avg_gain / (avg_gain + avg_loss)`, or `50.0` when both averages are zero.
#[derive(Debug, Clone)]
pub struct VolumeRsi {
/// Number of volume changes used for seeding and as the smoothing length.
period: usize,
/// Volume of the most recently seen candle; `None` until the first update.
prev_volume: Option<f64>,
/// Sum of gains accumulated while seeding.
seed_gains: f64,
/// Sum of losses (as positive magnitudes) accumulated while seeding.
seed_losses: f64,
/// Number of volume changes accumulated into the seed sums so far.
seed_count: usize,
/// Running average gain; `None` until seeding has completed.
avg_gain: Option<f64>,
/// Running average loss; `None` until seeding has completed.
avg_loss: Option<f64>,
/// Most recently emitted value; `None` until the first emission.
last: Option<f64>,
}
impl VolumeRsi {
pub fn new(period: usize) -> Result<Self> {
if period == 0 {
return Err(Error::PeriodZero);
}
Ok(Self {
period,
prev_volume: None,
seed_gains: 0.0,
seed_losses: 0.0,
seed_count: 0,
avg_gain: None,
avg_loss: None,
last: None,
})
}
pub const fn period(&self) -> usize {
self.period
}
pub const fn value(&self) -> Option<f64> {
self.last
}
fn rsi_from_avgs(avg_gain: f64, avg_loss: f64) -> f64 {
let denom = avg_gain + avg_loss;
if denom == 0.0 {
50.0
} else {
100.0 * (avg_gain / denom)
}
}
}
impl Indicator for VolumeRsi {
    type Input = Candle;
    type Output = f64;

    /// Feeds one candle and returns the new oscillator value once enough data has been seen.
    ///
    /// Returns `None` for the first candle (there is no previous volume to
    /// compare against) and while fewer than `period` volume changes have been
    /// accumulated. From then on every call returns `Some(value)`.
    fn update(&mut self, candle: Candle) -> Option<f64> {
        let current = candle.volume;

        // Store the new volume and, if there was an earlier one, take the difference.
        let delta = match self.prev_volume.replace(current) {
            Some(previous) => current - previous,
            None => return None,
        };

        // Split the change into a non-negative gain and a non-negative loss.
        // A zero (or NaN) change contributes nothing to either side.
        let (gain, loss) = if delta > 0.0 {
            (delta, 0.0)
        } else if delta < 0.0 {
            (0.0, -delta)
        } else {
            (0.0, 0.0)
        };

        let period = self.period as f64;
        let (avg_gain, avg_loss) = match (self.avg_gain, self.avg_loss) {
            // Already seeded: blend the new change into the running averages.
            (Some(prev_gain), Some(prev_loss)) => (
                (prev_gain * (period - 1.0) + gain) / period,
                (prev_loss * (period - 1.0) + loss) / period,
            ),
            // Still seeding: accumulate, and only proceed once `period` changes are in.
            _ => {
                self.seed_gains += gain;
                self.seed_losses += loss;
                self.seed_count += 1;
                if self.seed_count != self.period {
                    return None;
                }
                (self.seed_gains / period, self.seed_losses / period)
            }
        };

        self.avg_gain = Some(avg_gain);
        self.avg_loss = Some(avg_loss);
        let rsi = Self::rsi_from_avgs(avg_gain, avg_loss);
        self.last = Some(rsi);
        Some(rsi)
    }

    /// Discards all accumulated state while keeping the configured period.
    fn reset(&mut self) {
        *self = Self {
            period: self.period,
            prev_volume: None,
            seed_gains: 0.0,
            seed_losses: 0.0,
            seed_count: 0,
            avg_gain: None,
            avg_loss: None,
            last: None,
        };
    }

    /// Number of candles needed before the first value: one to establish the
    /// previous volume plus `period` changes for seeding.
    fn warmup_period(&self) -> usize {
        self.period + 1
    }

    /// Returns `true` once at least one value has been emitted.
    fn is_ready(&self) -> bool {
        self.value().is_some()
    }

    /// Returns the indicator's display name.
    fn name(&self) -> &'static str {
        "VolumeRsi"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Builds a candle with fixed prices so that only `volume` varies between inputs.
    fn vol_candle(volume: f64) -> Candle {
        Candle::new_unchecked(100.0, 101.0, 99.0, 100.5, volume, 0)
    }

    /// Turns a sequence of volumes into a vector of candles.
    fn candles_from(volumes: impl Iterator<Item = f64>) -> Vec<Candle> {
        volumes.map(vol_candle).collect()
    }

    /// Runs a fresh indicator over `volumes` and returns the final emitted value.
    fn final_value(period: usize, volumes: impl Iterator<Item = f64>) -> f64 {
        let candles = candles_from(volumes);
        VolumeRsi::new(period)
            .unwrap()
            .batch(&candles)
            .into_iter()
            .flatten()
            .last()
            .unwrap()
    }

    #[test]
    fn rejects_zero_period() {
        assert!(matches!(VolumeRsi::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let indicator = VolumeRsi::new(14).unwrap();
        assert_eq!(indicator.period(), 14);
        assert_eq!(indicator.warmup_period(), 15);
        assert_eq!(indicator.name(), "VolumeRsi");
        assert!(!indicator.is_ready());
        assert_eq!(indicator.value(), None);
    }

    #[test]
    fn first_emission_at_warmup_period() {
        let candles = candles_from((0..6_i32).map(|i| 1_000.0 + f64::from(i)));
        let out = VolumeRsi::new(3).unwrap().batch(&candles);
        // The first three inputs are warm-up; the fourth produces the first value.
        assert!(out.iter().take(3).all(|o| o.is_none()));
        assert!(out[3].is_some());
    }

    #[test]
    fn rising_volume_is_one_hundred() {
        let last = final_value(5, (1..=40_i32).map(|i| f64::from(i) * 100.0));
        assert_relative_eq!(last, 100.0, epsilon = 1e-9);
    }

    #[test]
    fn falling_volume_is_zero() {
        let last = final_value(5, (1..=40_i32).map(|i| 5_000.0 - f64::from(i) * 100.0));
        assert_relative_eq!(last, 0.0, epsilon = 1e-9);
    }

    #[test]
    fn flat_volume_is_neutral() {
        let last = final_value(3, std::iter::repeat(2_000.0).take(20));
        assert_relative_eq!(last, 50.0, epsilon = 1e-12);
    }

    #[test]
    fn output_in_range() {
        let candles =
            candles_from((0..200_i32).map(|i| 1_000.0 + (f64::from(i) * 0.3).sin() * 600.0));
        let mut indicator = VolumeRsi::new(14).unwrap();
        let all_bounded = indicator
            .batch(&candles)
            .into_iter()
            .flatten()
            .all(|o| (0.0..=100.0).contains(&o));
        assert!(all_bounded);
    }

    #[test]
    fn reset_clears_state() {
        let candles = candles_from((0..20_i32).map(|i| 1_000.0 + f64::from(i)));
        let mut indicator = VolumeRsi::new(3).unwrap();
        indicator.batch(&candles);
        assert!(indicator.is_ready());

        indicator.reset();
        assert!(!indicator.is_ready());
        assert_eq!(indicator.value(), None);
        // After a reset the next candle is treated as the very first one again.
        assert_eq!(indicator.update(vol_candle(1_000.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles =
            candles_from((0..120_i32).map(|i| 1_000.0 + (f64::from(i) * 0.25).sin() * 500.0));
        let batched = VolumeRsi::new(14).unwrap().batch(&candles);

        let mut streaming = VolumeRsi::new(14).unwrap();
        let streamed: Vec<_> = candles.iter().map(|c| streaming.update(*c)).collect();

        assert_eq!(batched, streamed);
    }
}