use std::collections::HashMap;
use crate::error::IndicatorError;
use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
use crate::registry::{param_str, param_usize};
use crate::types::Candle;
/// Configuration for [`Rsi`]: look-back length and the price series to read.
#[derive(Debug, Clone)]
pub struct RsiParams {
    /// Look-back length in bars. The first RSI value needs `period` price
    /// differences, i.e. `period + 1` candles.
    pub period: usize,
    /// Price series extracted from each candle before the RSI is computed.
    pub column: PriceColumn,
}
impl Default for RsiParams {
fn default() -> Self {
Self {
period: 14,
column: PriceColumn::Close,
}
}
}
/// Relative Strength Index indicator, configured by [`RsiParams`].
#[derive(Debug, Clone)]
pub struct Rsi {
    /// Look-back period and source column used by `calculate`.
    pub params: RsiParams,
}
impl Rsi {
pub fn new(params: RsiParams) -> Self {
Self { params }
}
pub fn with_period(period: usize) -> Self {
Self::new(RsiParams {
period,
..Default::default()
})
}
fn output_key(&self) -> String {
format!("RSI_{}", self.params.period)
}
}
impl Indicator for Rsi {
fn name(&self) -> &'static str {
"RSI"
}
fn required_len(&self) -> usize {
self.params.period + 1
}
fn required_columns(&self) -> &[&'static str] {
&["close"]
}
fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
self.check_len(candles)?;
let prices = self.params.column.extract(candles);
let n = prices.len();
let p = self.params.period;
let mut values = vec![f64::NAN; n];
let mut avg_gain = 0.0_f64;
let mut avg_loss = 0.0_f64;
for i in 1..=p {
let delta = prices[i] - prices[i - 1];
if delta > 0.0 {
avg_gain += delta;
} else {
avg_loss += -delta;
}
}
avg_gain /= p as f64;
avg_loss /= p as f64;
values[p] = rsi_from(avg_gain, avg_loss);
let w = (p - 1) as f64;
for i in (p + 1)..n {
let delta = prices[i] - prices[i - 1];
let gain = if delta > 0.0 { delta } else { 0.0 };
let loss = if delta < 0.0 { -delta } else { 0.0 };
avg_gain = (avg_gain * w + gain) / p as f64;
avg_loss = (avg_loss * w + loss) / p as f64;
values[i] = rsi_from(avg_gain, avg_loss);
}
Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
}
}
/// Converts an average gain and average loss into an RSI reading in `[0, 100]`.
///
/// With a zero average loss the gain/loss ratio is undefined, so the result is
/// `100.0` when there were gains and the neutral `50.0` when there was no movement.
#[inline]
fn rsi_from(avg_gain: f64, avg_loss: f64) -> f64 {
    match (avg_gain == 0.0, avg_loss == 0.0) {
        (_, false) => 100.0 - 100.0 / (1.0 + avg_gain / avg_loss),
        (false, true) => 100.0,
        (true, true) => 50.0,
    }
}
/// Registry constructor: builds a boxed [`Rsi`] from string parameters.
///
/// Recognised keys:
/// - `period`: read with `param_usize`, default `14`;
/// - `column`: one of `open`, `high`, `low`, `volume`, `close` (default `close`).
///   Any other value silently selects the close column.
///
/// # Errors
/// Returns whatever error `param_usize` reports for the `period` key.
pub fn factory<S>(
    params: &HashMap<String, String, S>,
) -> Result<Box<dyn Indicator>, IndicatorError>
where
    S: ::std::hash::BuildHasher,
{
    let period = param_usize(params, "period", 14)?;
    let column_name = param_str(params, "column", "close");
    let column = match column_name {
        "open" => PriceColumn::Open,
        "high" => PriceColumn::High,
        "low" => PriceColumn::Low,
        "volume" => PriceColumn::Volume,
        _ => PriceColumn::Close,
    };
    let indicator = Rsi::new(RsiParams { period, column });
    Ok(Box::new(indicator))
}
#[cfg(test)]
mod tests {
    // Unit tests for the RSI indicator. Expected values in the fixed-input tests
    // are hand-computed from the simple-average seed and one Wilder smoothing step.
    use super::*;

    /// Builds flat candles (open = high = low = close = price) with unit volume
    /// and sequential integer timestamps, one candle per entry in `closes`.
    fn make_candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn rsi_insufficient_data() {
        // Period 14 needs period + 1 = 15 candles; only 10 are supplied.
        let err = Rsi::with_period(14)
            .calculate(&make_candles(&[1.0; 10]))
            .unwrap_err();
        assert!(matches!(err, IndicatorError::InsufficientData { .. }));
    }

    #[test]
    fn rsi_leading_nans() {
        // Indices 0..period have no complete window and must be NaN;
        // index `period` holds the first (seed) value.
        let prices: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&prices))
            .unwrap();
        let vals = out.get("RSI_14").unwrap();
        for (i, &v) in vals.iter().enumerate().take(14) {
            assert!(v.is_nan(), "expected NaN at [{i}], got {v}");
        }
        assert!(!vals[14].is_nan());
    }

    #[test]
    fn rsi_constant_gains_is_100() {
        // Strictly rising prices: average loss stays 0, so RSI saturates at 100.
        let prices: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&prices))
            .unwrap();
        for &v in out.get("RSI_14").unwrap().iter().filter(|v| !v.is_nan()) {
            assert!((v - 100.0).abs() < 1e-9, "expected 100.0, got {v}");
        }
    }

    #[test]
    fn rsi_constant_losses_is_0() {
        // Strictly falling prices: average gain stays 0, so RSI is 0.
        let prices: Vec<f64> = (0..20).map(|i| 100.0 - i as f64).collect();
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&prices))
            .unwrap();
        for &v in out.get("RSI_14").unwrap().iter().filter(|v| !v.is_nan()) {
            assert!(v.abs() < 1e-9, "expected 0.0, got {v}");
        }
    }

    #[test]
    fn rsi_alternating_equal_moves_is_50() {
        // Moves alternate +1, -1, ...; the 14 seed deltas are 7 gains and 7 losses
        // of equal size, so RS = 1 and the seed RSI is exactly 50.
        let mut prices = vec![100.0_f64];
        for i in 0..19 {
            let last = *prices.last().unwrap();
            prices.push(if i % 2 == 0 { last + 1.0 } else { last - 1.0 });
        }
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&prices))
            .unwrap();
        assert!((out.get("RSI_14").unwrap()[14] - 50.0).abs() < 1e-9);
    }

    #[test]
    fn rsi_known_seed_value() {
        // Deltas +1, -2, +2: avg gain = 3/3 = 1, avg loss = 2/3, RS = 1.5,
        // RSI = 100 - 100 / 2.5 = 60.
        let out = Rsi::with_period(3)
            .calculate(&make_candles(&[10.0, 11.0, 9.0, 11.0]))
            .unwrap();
        assert!((out.get("RSI_3").unwrap()[3] - 60.0).abs() < 1e-6);
    }

    #[test]
    fn rsi_wilder_smoothing_step() {
        // Continues the seed above with one more delta of -1 (gain 0, loss 1);
        // each average is updated as (prev * (period - 1) + x) / period.
        let out = Rsi::with_period(3)
            .calculate(&make_candles(&[10.0, 11.0, 9.0, 11.0, 10.0]))
            .unwrap();
        let ag = (1.0_f64 * 2.0) / 3.0;
        let al = (2.0_f64 / 3.0 * 2.0 + 1.0) / 3.0;
        let expected = 100.0 - 100.0 / (1.0 + ag / al);
        assert!((out.get("RSI_3").unwrap()[4] - expected).abs() < 1e-9);
    }

    #[test]
    fn rsi_stays_in_range() {
        // Every defined value of a sine-wave series must lie in [0, 100].
        let prices: Vec<f64> = (0..50)
            .map(|i| 100.0 + (i as f64 * 0.3).sin() * 10.0)
            .collect();
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&prices))
            .unwrap();
        for &v in out.get("RSI_14").unwrap() {
            if !v.is_nan() {
                assert!((0.0..=100.0).contains(&v), "out of range: {v}");
            }
        }
    }

    #[test]
    fn factory_creates_rsi() {
        // No params: the factory's default period is 14, so 15 candles are required.
        let ind = factory(&HashMap::new()).unwrap();
        assert_eq!(ind.name(), "RSI");
        assert_eq!(ind.required_len(), 15);
    }
}