use crate::indicators::utils::{validate_data_length, validate_period};
use crate::indicators::{Candle, Indicator, IndicatorError};
use std::collections::VecDeque;
/// Relative Strength Index indicator state.
///
/// Holds the configuration (`period`) plus the bookkeeping needed by the
/// streaming `next` implementations: the previously seen price, the most
/// recent gains/losses, and the running average gain/loss.
#[derive(Debug)]
pub struct Rsi {
/// Number of price changes averaged together; used as the window length
/// and as the divisor of the averages.
period: usize,
/// Last price handed to `next`; `None` until the first value arrives or
/// after `reset_state`.
prev_price: Option<f64>,
/// Most recent upward moves (0.0 for non-upward moves), capped at `period`
/// entries by the streaming path.
gains: VecDeque<f64>,
/// Most recent downward moves as positive magnitudes (0.0 for non-downward
/// moves), capped at `period` entries by the streaming path.
losses: VecDeque<f64>,
/// Running average gain; `None` until `period` changes have been observed
/// by the streaming path.
avg_gain: Option<f64>,
/// Running average loss; `None` until `period` changes have been observed
/// by the streaming path.
avg_loss: Option<f64>,
}
impl Rsi {
    /// Creates an RSI indicator that averages over `period` price changes.
    ///
    /// # Errors
    ///
    /// Returns whatever error `validate_period(period, 1)` reports; the
    /// buffers are only allocated once that check has passed.
    pub fn new(period: usize) -> Result<Self, IndicatorError> {
        validate_period(period, 1)?;

        let gains = VecDeque::with_capacity(period);
        let losses = VecDeque::with_capacity(period);

        Ok(Self {
            period,
            prev_price: None,
            gains,
            losses,
            avg_gain: None,
            avg_loss: None,
        })
    }

    /// Converts an average gain / average loss pair into an RSI value.
    ///
    /// * Both averages zero (no movement at all) yields `50.0`.
    /// * Zero average loss with a non-zero average gain yields `100.0`.
    /// * Otherwise the result is `100 - 100 / (1 + avg_gain / avg_loss)`.
    fn calculate_rsi(avg_gain: f64, avg_loss: f64) -> f64 {
        // A zero loss would divide by zero below, so both special cases are
        // decided here before the ratio is formed.
        if avg_loss == 0.0 {
            return if avg_gain == 0.0 { 50.0 } else { 100.0 };
        }

        let relative_strength = avg_gain / avg_loss;
        100.0 - (100.0 / (1.0 + relative_strength))
    }

    /// Discards all streaming state while keeping the configured period.
    ///
    /// Afterwards the indicator behaves as if freshly constructed: the next
    /// value passed to `next` is treated as the first observation.
    pub fn reset_state(&mut self) {
        self.gains.clear();
        self.losses.clear();
        self.avg_gain = None;
        self.avg_loss = None;
        self.prev_price = None;
    }
}
impl Indicator<f64, f64> for Rsi {
fn calculate(&mut self, data: &[f64]) -> Result<Vec<f64>, IndicatorError> {
validate_data_length(data, self.period + 1)?;
let n = data.len();
let mut result = Vec::with_capacity(n - self.period);
self.reset_state();
let mut price_changes = Vec::with_capacity(n - 1);
for i in 1..n {
price_changes.push(data[i] - data[i - 1]);
}
let mut gains = Vec::with_capacity(self.period);
let mut losses = Vec::with_capacity(self.period);
for &change in price_changes.iter().take(self.period) {
if change > 0.0 {
gains.push(change);
losses.push(0.0);
} else {
gains.push(0.0);
losses.push(-change);
}
}
let mut avg_gain = gains.iter().sum::<f64>() / self.period as f64;
let mut avg_loss = losses.iter().sum::<f64>() / self.period as f64;
result.push(Self::calculate_rsi(avg_gain, avg_loss));
for change in price_changes.iter().skip(self.period).copied() {
let gain = if change > 0.0 { change } else { 0.0 };
let loss = if change < 0.0 { -change } else { 0.0 };
avg_gain = (avg_gain * (self.period - 1) as f64 + gain) / self.period as f64;
avg_loss = (avg_loss * (self.period - 1) as f64 + loss) / self.period as f64;
result.push(Self::calculate_rsi(avg_gain, avg_loss));
}
Ok(result)
}
fn next(&mut self, value: f64) -> Result<Option<f64>, IndicatorError> {
if let Some(prev) = self.prev_price {
let change = value - prev;
let gain = if change > 0.0 { change } else { 0.0 };
let loss = if change < 0.0 { -change } else { 0.0 };
self.gains.push_back(gain);
self.losses.push_back(loss);
if self.gains.len() > self.period {
self.gains.pop_front();
self.losses.pop_front();
}
if self.gains.len() < self.period {
self.avg_gain = None;
self.avg_loss = None;
self.prev_price = Some(value);
return Ok(None);
}
if let (Some(avg_gain), Some(avg_loss)) = (self.avg_gain, self.avg_loss) {
self.avg_gain =
Some((avg_gain * (self.period - 1) as f64 + gain) / self.period as f64);
self.avg_loss =
Some((avg_loss * (self.period - 1) as f64 + loss) / self.period as f64);
} else {
self.avg_gain = Some(self.gains.iter().sum::<f64>() / self.period as f64);
self.avg_loss = Some(self.losses.iter().sum::<f64>() / self.period as f64);
}
let rsi = Self::calculate_rsi(self.avg_gain.unwrap(), self.avg_loss.unwrap());
self.prev_price = Some(value);
Ok(Some(rsi))
} else {
self.prev_price = Some(value);
Ok(None)
}
}
fn reset(&mut self) {
self.reset_state();
}
}
/// Candle-based RSI: identical to the price-based RSI applied to each
/// candle's `close`. Open, high, low, volume and timestamp are ignored.
///
/// Both methods delegate to the `Indicator<f64, f64>` implementation so the
/// seeding and smoothing rules live in exactly one place and the two input
/// kinds cannot drift apart.
impl Indicator<Candle, f64> for Rsi {
    /// Computes the RSI series from the closing prices of `data`.
    ///
    /// Returns `data.len() - period` values, matching what the price-based
    /// `calculate` returns for the same closes.
    ///
    /// # Errors
    ///
    /// Returns the error from `validate_data_length(data, period + 1)` when
    /// the slice is too short; nothing is allocated and no state is touched
    /// in that case.
    fn calculate(&mut self, data: &[Candle]) -> Result<Vec<f64>, IndicatorError> {
        // Validate the candle slice itself first so a too-short input is
        // rejected before the closes are copied out.
        validate_data_length(data, self.period + 1)?;

        let close_prices: Vec<f64> = data.iter().map(|candle| candle.close).collect();
        <Self as Indicator<f64, f64>>::calculate(self, &close_prices)
    }

    /// Feeds one candle into the streaming calculation using its `close`.
    ///
    /// Returns `Ok(None)` until enough closes have been seen, then
    /// `Ok(Some(rsi))`, exactly as the price-based `next` does.
    fn next(&mut self, candle: Candle) -> Result<Option<f64>, IndicatorError> {
        <Self as Indicator<f64, f64>>::next(self, candle.close)
    }

    /// Clears all streaming state; equivalent to `reset_state`.
    fn reset(&mut self) {
        self.reset_state();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `Candle` literal.
    /// Argument order: timestamp, open, high, low, close, volume.
    macro_rules! candle {
        ($timestamp:expr, $open:expr, $high:expr, $low:expr, $close:expr, $volume:expr) => {
            Candle {
                timestamp: $timestamp,
                open: $open,
                high: $high,
                low: $low,
                close: $close,
                volume: $volume,
            }
        };
    }

    /// Tolerance for comparing two RSI series that should be identical.
    const SERIES_EPSILON: f64 = 0.000001;

    /// A positive period is accepted, a zero period is rejected.
    #[test]
    fn test_rsi_new() {
        assert!(Rsi::new(0).is_err());
        assert!(Rsi::new(14).is_ok());
    }

    /// Batch calculation yields `len - period` values; the seed value for
    /// changes +1.0, -0.5, +1.0 is 80.
    #[test]
    fn test_rsi_calculation() {
        let prices = vec![10.0, 11.0, 10.5, 11.5, 12.0, 11.0, 11.5];
        let mut indicator = Rsi::new(3).unwrap();

        let series = indicator.calculate(&prices).unwrap();

        assert_eq!(series.len(), 4);
        assert!((series[0] - 80.0).abs() < 0.01);
        let tail = series.last().unwrap();
        assert!((0.0..=100.0).contains(tail));
    }

    /// Streaming yields `None` during warm-up and values afterwards.
    #[test]
    fn test_rsi_next() {
        let mut indicator = Rsi::new(3).unwrap();

        let warmup: Vec<f64> = vec![10.0, 11.0, 10.5];
        for &price in warmup.iter() {
            assert_eq!(indicator.next(price).unwrap(), None);
        }

        let first = indicator.next(11.5).unwrap();
        assert!(first.is_some());
        let first = first.unwrap();
        assert!((0.0..=100.0).contains(&first));

        assert!(indicator.next(12.0).unwrap().is_some());
        assert!(indicator.next(11.0).unwrap().is_some());
    }

    /// After a reset the next price is treated as the first observation.
    #[test]
    fn test_rsi_reset() {
        let mut indicator = Rsi::new(3).unwrap();

        let history: Vec<f64> = vec![10.0, 11.0, 10.5, 11.5];
        for &price in history.iter() {
            indicator.next(price).unwrap();
        }

        indicator.reset_state();
        assert_eq!(indicator.next(12.0).unwrap(), None);
    }

    /// Batch calculation over candles matches batch calculation over closes.
    #[test]
    fn test_rsi_calculation_with_candles() {
        let candles = vec![
            candle!(1, 9.0, 10.5, 8.5, 10.0, 1000.0),
            candle!(2, 10.5, 11.5, 10.0, 11.0, 1200.0),
            candle!(3, 11.0, 11.5, 10.0, 10.5, 1100.0),
            candle!(4, 10.0, 12.0, 10.0, 11.5, 1300.0),
            candle!(5, 11.5, 12.5, 11.0, 12.0, 1400.0),
            candle!(6, 12.0, 12.0, 10.5, 11.0, 1500.0),
            candle!(7, 11.0, 12.0, 11.0, 11.5, 1600.0),
        ];

        let mut candle_rsi = Rsi::new(3).unwrap();
        let from_candles = candle_rsi.calculate(&candles).unwrap();
        assert_eq!(from_candles.len(), 4);
        assert!((from_candles[0] - 80.0).abs() < 0.01);

        let closes: Vec<f64> = candles.iter().map(|bar| bar.close).collect();
        let mut price_rsi = Rsi::new(3).unwrap();
        let from_prices = price_rsi.calculate(&closes).unwrap();

        for (lhs, rhs) in from_candles.iter().zip(from_prices.iter()) {
            assert!((lhs - rhs).abs() < SERIES_EPSILON);
        }
    }

    /// Streaming over candles matches streaming over their closes.
    #[test]
    fn test_rsi_next_with_candles() {
        let candles = [
            candle!(1, 9.0, 10.5, 8.5, 10.0, 1000.0),
            candle!(2, 10.5, 11.5, 10.0, 11.0, 1200.0),
            candle!(3, 11.0, 11.5, 10.0, 10.5, 1100.0),
            candle!(4, 10.0, 12.0, 10.0, 11.5, 1300.0),
        ];

        let mut candle_rsi = Rsi::new(3).unwrap();
        let mut price_rsi = Rsi::new(3).unwrap();

        // Warm both indicators up side by side; they share no state.
        for bar in candles[..3].iter() {
            assert_eq!(candle_rsi.next(*bar).unwrap(), None);
            price_rsi.next(bar.close).unwrap();
        }

        let from_candle = candle_rsi.next(candles[3]).unwrap();
        assert!(from_candle.is_some());
        let from_candle = from_candle.unwrap();
        assert!((0.0..=100.0).contains(&from_candle));

        let from_price = price_rsi.next(candles[3].close).unwrap().unwrap();
        assert!((from_candle - from_price).abs() < SERIES_EPSILON);
    }

    /// Only `close` influences the result: two candle sets with equal closes
    /// but different open/high/low/volume give the same series.
    #[test]
    fn test_rsi_with_candles_ignores_other_price_data() {
        let noisy = vec![
            candle!(1, 15.0, 20.0, 5.0, 10.0, 5000.0),
            candle!(2, 25.0, 30.0, 8.0, 11.0, 6000.0),
            candle!(3, 5.0, 15.0, 2.0, 10.5, 7000.0),
            candle!(4, 20.0, 25.0, 9.0, 11.5, 8000.0),
        ];
        let tidy = vec![
            candle!(1, 9.0, 10.5, 8.5, 10.0, 1000.0),
            candle!(2, 10.5, 11.5, 10.0, 11.0, 1200.0),
            candle!(3, 11.0, 11.5, 10.0, 10.5, 1100.0),
            candle!(4, 10.0, 12.0, 10.0, 11.5, 1300.0),
        ];

        let mut indicator = Rsi::new(3).unwrap();
        let noisy_series = indicator.calculate(&noisy).unwrap();
        indicator.reset_state();
        let tidy_series = indicator.calculate(&tidy).unwrap();

        assert_eq!(noisy_series.len(), tidy_series.len());
        for (lhs, rhs) in noisy_series.iter().zip(tidy_series.iter()) {
            assert!((lhs - rhs).abs() < SERIES_EPSILON);
        }
    }

    /// After a reset the next candle is treated as the first observation.
    #[test]
    fn test_rsi_with_candles_reset() {
        let mut indicator = Rsi::new(3).unwrap();

        let history = vec![
            candle!(1, 9.0, 10.5, 8.5, 10.0, 1000.0),
            candle!(2, 10.5, 11.5, 10.0, 11.0, 1200.0),
            candle!(3, 11.0, 11.5, 10.0, 10.5, 1100.0),
            candle!(4, 10.0, 12.0, 10.0, 11.5, 1300.0),
        ];
        for bar in history {
            indicator.next(bar).unwrap();
        }

        indicator.reset_state();

        let after_reset = indicator
            .next(candle!(5, 11.5, 12.5, 11.0, 12.0, 1400.0))
            .unwrap();
        assert_eq!(after_reset, None);
    }

    /// Flat closes give 50, strictly rising closes give 100, strictly
    /// falling closes give 0.
    #[test]
    fn test_rsi_with_candles_edge_cases() {
        let scenarios = vec![
            (
                // Flat: every close is 10.0.
                vec![
                    candle!(1, 9.0, 10.5, 8.5, 10.0, 1000.0),
                    candle!(2, 10.5, 11.5, 10.0, 10.0, 1200.0),
                    candle!(3, 11.0, 11.5, 10.0, 10.0, 1100.0),
                    candle!(4, 10.0, 12.0, 10.0, 10.0, 1300.0),
                ],
                50.0,
            ),
            (
                // Rising: closes 10, 11, 12, 13.
                vec![
                    candle!(1, 9.0, 10.5, 8.5, 10.0, 1000.0),
                    candle!(2, 10.5, 11.5, 10.0, 11.0, 1200.0),
                    candle!(3, 11.0, 11.5, 10.0, 12.0, 1100.0),
                    candle!(4, 12.0, 13.0, 11.0, 13.0, 1300.0),
                ],
                100.0,
            ),
            (
                // Falling: closes 14, 13, 12, 11.
                vec![
                    candle!(1, 14.0, 15.0, 13.5, 14.0, 1000.0),
                    candle!(2, 13.5, 14.0, 13.0, 13.0, 1200.0),
                    candle!(3, 13.0, 13.0, 12.0, 12.0, 1100.0),
                    candle!(4, 12.0, 12.0, 11.0, 11.0, 1300.0),
                ],
                0.0,
            ),
        ];

        for (candles, expected) in scenarios {
            let mut indicator = Rsi::new(3).unwrap();
            let series = indicator.calculate(&candles).unwrap();
            assert_eq!(series[0], expected);
        }
    }
}