Skip to main content

indicators/momentum/
rsi.rs

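//! Relative Strength Index (RSI) — Wilder's smoothed method.
//!
//! Python source: `indicators/momentum/rsi.py :: class RSI`
//!
//! # Algorithm
//!
//! `price` is the candle field selected by `RsiParams::column` (close by default).
//!
//! 1. `delta[i] = price[i] - price[i-1]`, split into `gain = max(delta, 0)` and
//!    `loss = max(-delta, 0)`.
//! 2. **Seed** (deltas 1..=period): simple mean of the gains and of the losses.
//!    The first RSI value is emitted at index `period`; earlier entries are NaN.
//! 3. **Wilder smoothing** (index > period), applied to both averages:
//!    `avg_gain = (prev * (period-1) + gain) / period`
//! 4. `RSI = 100 - 100 / (1 + avg_gain / avg_loss)`. When `avg_loss` is 0 the
//!    result is 100, or 50 if `avg_gain` is also 0 (a completely flat window).
//!
//! This follows the TA-Lib / TradingView convention: an SMA seed followed by
//! Wilder's recursive smoothing, rather than a rolling-SMA ("Cutler") RSI.
//! Other libraries may handle the flat-window case differently.
//!
//! Output column: `"RSI_{period}"` — e.g. `"RSI_14"`.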
16
17use std::collections::HashMap;
18
19use crate::error::IndicatorError;
20use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
21use crate::registry::{param_str, param_usize};
22use crate::types::Candle;
23
24// ── Params ────────────────────────────────────────────────────────────────────
25
26#[derive(Debug, Clone)]
27pub struct RsiParams {
28    /// Look-back period. Wilder's original default: 14.
29    pub period: usize,
30    /// Price field. Default: Close.
31    pub column: PriceColumn,
32}
33
34impl Default for RsiParams {
35    fn default() -> Self {
36        Self {
37            period: 14,
38            column: PriceColumn::Close,
39        }
40    }
41}
42
43// ── Indicator struct ──────────────────────────────────────────────────────────
44
45#[derive(Debug, Clone)]
46pub struct Rsi {
47    pub params: RsiParams,
48}
49
50impl Rsi {
51    pub fn new(params: RsiParams) -> Self {
52        Self { params }
53    }
54    pub fn with_period(period: usize) -> Self {
55        Self::new(RsiParams {
56            period,
57            ..Default::default()
58        })
59    }
60    fn output_key(&self) -> String {
61        format!("RSI_{}", self.params.period)
62    }
63}
64
65impl Indicator for Rsi {
66    fn name(&self) -> &str {
67        "RSI"
68    }
69
70    /// Need `period + 1` bars: `period` deltas to seed, output starts at index `period`.
71    fn required_len(&self) -> usize {
72        self.params.period + 1
73    }
74
75    fn required_columns(&self) -> &[&'static str] {
76        &["close"]
77    }
78
79    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
80        self.check_len(candles)?;
81
82        let prices = self.params.column.extract(candles);
83        let n = prices.len();
84        let p = self.params.period;
85        let mut values = vec![f64::NAN; n];
86
87        // ── Seed: SMA of first `p` deltas ────────────────────────────────────
88        let mut avg_gain = 0.0_f64;
89        let mut avg_loss = 0.0_f64;
90        for i in 1..=p {
91            let delta = prices[i] - prices[i - 1];
92            if delta > 0.0 {
93                avg_gain += delta;
94            } else {
95                avg_loss += -delta;
96            }
97        }
98        avg_gain /= p as f64;
99        avg_loss /= p as f64;
100        values[p] = rsi_from(avg_gain, avg_loss);
101
102        // ── Wilder smoothing for remaining bars ───────────────────────────────
103        let w = (p - 1) as f64;
104        for i in (p + 1)..n {
105            let delta = prices[i] - prices[i - 1];
106            let gain = if delta > 0.0 { delta } else { 0.0 };
107            let loss = if delta < 0.0 { -delta } else { 0.0 };
108            avg_gain = (avg_gain * w + gain) / p as f64;
109            avg_loss = (avg_loss * w + loss) / p as f64;
110            values[i] = rsi_from(avg_gain, avg_loss);
111        }
112
113        Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
114    }
115}
116
/// Converts smoothed average gain and loss into an RSI value.
///
/// For finite, non-negative averages the result lies in `[0, 100]`. A zero
/// `avg_loss` would leave the gain/loss ratio undefined, so it is special-cased:
/// 100 when there were gains, and the neutral 50 when the window was flat.
#[inline]
fn rsi_from(avg_gain: f64, avg_loss: f64) -> f64 {
    match (avg_gain, avg_loss) {
        (gain, loss) if loss != 0.0 => {
            let relative_strength = gain / loss;
            100.0 - 100.0 / (1.0 + relative_strength)
        }
        (gain, _) if gain == 0.0 => 50.0,
        _ => 100.0,
    }
}
125
126// ── Registry factory ──────────────────────────────────────────────────────────
127
128pub fn factory(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
129    let period = param_usize(params, "period", 14)?;
130    let column = match param_str(params, "column", "close") {
131        "open" => PriceColumn::Open,
132        "high" => PriceColumn::High,
133        "low" => PriceColumn::Low,
134        "volume" => PriceColumn::Volume,
135        _ => PriceColumn::Close,
136    };
137    Ok(Box::new(Rsi::new(RsiParams { period, column })))
138}
139
140// ── Tests ─────────────────────────────────────────────────────────────────────
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles whose OHLC fields all equal the given close price.
    fn make_candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i as i64,
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn rsi_insufficient_data() {
        let err = Rsi::with_period(14).calculate(&make_candles(&[1.0; 10])).unwrap_err();
        assert!(matches!(err, IndicatorError::InsufficientData { .. }));
    }

    #[test]
    fn rsi_leading_nans() {
        let prices: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let out = Rsi::with_period(14).calculate(&make_candles(&prices)).unwrap();
        let vals = out.get("RSI_14").unwrap();
        for (i, v) in vals.iter().take(14).enumerate() {
            assert!(v.is_nan(), "expected NaN at [{i}], got {v}");
        }
        assert!(!vals[14].is_nan());
    }

    #[test]
    fn rsi_output_len_matches_input() {
        let prices: Vec<f64> = (0..30).map(|i| 100.0 + (i % 5) as f64).collect();
        let out = Rsi::with_period(14).calculate(&make_candles(&prices)).unwrap();
        assert_eq!(out.get("RSI_14").unwrap().len(), prices.len());
    }

    #[test]
    fn rsi_constant_gains_is_100() {
        // All deltas positive → avg_loss=0, avg_gain>0 → RSI=100.
        let prices: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let out = Rsi::with_period(14).calculate(&make_candles(&prices)).unwrap();
        for &v in out.get("RSI_14").unwrap().iter().filter(|v| !v.is_nan()) {
            assert!((v - 100.0).abs() < 1e-9, "expected 100.0, got {v}");
        }
    }

    #[test]
    fn rsi_constant_losses_is_0() {
        // All deltas negative → avg_gain=0, avg_loss>0 → RSI=0.
        let prices: Vec<f64> = (0..20).map(|i| 100.0 - i as f64).collect();
        let out = Rsi::with_period(14).calculate(&make_candles(&prices)).unwrap();
        for &v in out.get("RSI_14").unwrap().iter().filter(|v| !v.is_nan()) {
            assert!(v.abs() < 1e-9, "expected 0.0, got {v}");
        }
    }

    #[test]
    fn rsi_flat_prices_is_50() {
        // Every delta is 0 → avg_gain = avg_loss = 0 → neutral RSI of 50.
        let out = Rsi::with_period(14).calculate(&make_candles(&[42.0; 20])).unwrap();
        let vals = out.get("RSI_14").unwrap();
        assert!(vals[14..].iter().all(|&v| (v - 50.0).abs() < 1e-12));
    }

    #[test]
    fn rsi_alternating_equal_moves_is_50() {
        // +1, -1, +1, -1 ... with 14 deltas: 7×(+1) and 7×(−1).
        // avg_gain = 7/14 = 0.5, avg_loss = 7/14 = 0.5 → RSI = 50 exactly.
        let mut prices = vec![100.0_f64];
        for i in 0..19 {
            let last = *prices.last().unwrap();
            prices.push(if i % 2 == 0 { last + 1.0 } else { last - 1.0 });
        }
        let out = Rsi::with_period(14).calculate(&make_candles(&prices)).unwrap();
        assert!((out.get("RSI_14").unwrap()[14] - 50.0).abs() < 1e-9);
    }

    #[test]
    fn rsi_known_seed_value() {
        // period=3, prices=[10, 11, 9, 11].
        // Deltas: +1, -2, +2.
        // avg_gain=(1+0+2)/3=1.0, avg_loss=(0+2+0)/3=0.667
        // RSI[3] = 100 - 100/(1 + 1.0/(2/3)) = 100 - 100/2.5 = 60.0
        let out = Rsi::with_period(3)
            .calculate(&make_candles(&[10.0, 11.0, 9.0, 11.0]))
            .unwrap();
        assert!((out.get("RSI_3").unwrap()[3] - 60.0).abs() < 1e-6);
    }

    #[test]
    fn rsi_wilder_smoothing_step() {
        // Extend by one bar: prices=[10, 11, 9, 11, 10], delta[4]=-1.
        // After seed: avg_gain=1.0, avg_loss=2/3.
        // Wilder: avg_gain=(1.0*2+0)/3=2/3, avg_loss=(2/3*2+1)/3=7/9
        let out = Rsi::with_period(3)
            .calculate(&make_candles(&[10.0, 11.0, 9.0, 11.0, 10.0]))
            .unwrap();
        let ag = (1.0_f64 * 2.0) / 3.0;
        let al = (2.0_f64 / 3.0 * 2.0 + 1.0) / 3.0;
        let expected = 100.0 - 100.0 / (1.0 + ag / al);
        assert!((out.get("RSI_3").unwrap()[4] - expected).abs() < 1e-9);
    }

    #[test]
    fn rsi_stays_in_range() {
        let prices: Vec<f64> = (0..50).map(|i| 100.0 + (i as f64 * 0.3).sin() * 10.0).collect();
        let out = Rsi::with_period(14).calculate(&make_candles(&prices)).unwrap();
        for &v in out.get("RSI_14").unwrap() {
            if !v.is_nan() {
                assert!((0.0..=100.0).contains(&v), "out of range: {v}");
            }
        }
    }

    #[test]
    fn factory_creates_rsi() {
        let ind = factory(&HashMap::new()).unwrap();
        assert_eq!(ind.name(), "RSI");
        assert_eq!(ind.required_len(), 15);
    }

    #[test]
    fn factory_parses_period() {
        let params = HashMap::from([("period".to_string(), "5".to_string())]);
        let ind = factory(&params).unwrap();
        assert_eq!(ind.name(), "RSI");
        assert_eq!(ind.required_len(), 6);
    }
}