Skip to main content

indicators/momentum/
rsi.rs

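//! Relative Strength Index (RSI) — Wilder's smoothed method.
//!
//! Python source: `indicators/momentum/rsi.py :: class RSI`
//!
//! # Algorithm
//!
//! 1. `delta[i] = price[i] - price[i-1]`, where `price` is the configured column
//!    (close by default).
//! 2. **Seed** (deltas `1..=period`): simple mean of the gains and of the losses.
//! 3. **Wilder smoothing** (every bar after `period`), applied to gains and losses alike:
//!    `avg_gain = (prev_avg_gain * (period-1) + gain) / period`
//!    `avg_loss = (prev_avg_loss * (period-1) + loss) / period`
//! 4. `RSI = 100 - 100 / (1 + avg_gain / avg_loss)`. When `avg_loss == 0` the result
//!    is `100`, or `50` if `avg_gain` is also `0`.
//!
//! The first value is emitted at index `period`; earlier indices are `NaN`.
//!
//! This matches TA-Lib and TradingView. The first average is a plain mean of the
//! first `period` gains/losses (Wilder's seed). After that comes Wilder's recursive
//! smoothing (α = 1/period), not a rolling simple average of gains/losses.
//!
//! Output column: `"RSI_{period}"` — e.g. `"RSI_14"`.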
16
17use std::collections::HashMap;
18
19use crate::error::IndicatorError;
20use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
21use crate::registry::{param_str, param_usize};
22use crate::types::Candle;
23
24// ── Params ────────────────────────────────────────────────────────────────────
25
26#[derive(Debug, Clone)]
27pub struct RsiParams {
28    /// Look-back period. Wilder's original default: 14.
29    pub period: usize,
30    /// Price field. Default: Close.
31    pub column: PriceColumn,
32}
33
34impl Default for RsiParams {
35    fn default() -> Self {
36        Self {
37            period: 14,
38            column: PriceColumn::Close,
39        }
40    }
41}
42
43// ── Indicator struct ──────────────────────────────────────────────────────────
44
45#[derive(Debug, Clone)]
46pub struct Rsi {
47    pub params: RsiParams,
48}
49
50impl Rsi {
51    pub fn new(params: RsiParams) -> Self {
52        Self { params }
53    }
54    pub fn with_period(period: usize) -> Self {
55        Self::new(RsiParams {
56            period,
57            ..Default::default()
58        })
59    }
60    fn output_key(&self) -> String {
61        format!("RSI_{}", self.params.period)
62    }
63}
64
65impl Indicator for Rsi {
66    fn name(&self) -> &'static str {
67        "RSI"
68    }
69
70    /// Need `period + 1` bars: `period` deltas to seed, output starts at index `period`.
71    fn required_len(&self) -> usize {
72        self.params.period + 1
73    }
74
75    fn required_columns(&self) -> &[&'static str] {
76        &["close"]
77    }
78
79    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
80        self.check_len(candles)?;
81
82        let prices = self.params.column.extract(candles);
83        let n = prices.len();
84        let p = self.params.period;
85        let mut values = vec![f64::NAN; n];
86
87        // ── Seed: SMA of first `p` deltas ────────────────────────────────────
88        let mut avg_gain = 0.0_f64;
89        let mut avg_loss = 0.0_f64;
90        for i in 1..=p {
91            let delta = prices[i] - prices[i - 1];
92            if delta > 0.0 {
93                avg_gain += delta;
94            } else {
95                avg_loss += -delta;
96            }
97        }
98        avg_gain /= p as f64;
99        avg_loss /= p as f64;
100        values[p] = rsi_from(avg_gain, avg_loss);
101
102        // ── Wilder smoothing for remaining bars ───────────────────────────────
103        let w = (p - 1) as f64;
104        for i in (p + 1)..n {
105            let delta = prices[i] - prices[i - 1];
106            let gain = if delta > 0.0 { delta } else { 0.0 };
107            let loss = if delta < 0.0 { -delta } else { 0.0 };
108            avg_gain = (avg_gain * w + gain) / p as f64;
109            avg_loss = (avg_loss * w + loss) / p as f64;
110            values[i] = rsi_from(avg_gain, avg_loss);
111        }
112
113        Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
114    }
115}
116
/// Converts smoothed average gain/loss into an RSI value.
///
/// For non-negative finite inputs the result lies in `[0, 100]`. The ratio
/// `avg_gain / avg_loss` is undefined when there is no average loss, so the result
/// is pinned instead: `100` when there were gains, `50` for a completely flat window.
#[inline]
fn rsi_from(avg_gain: f64, avg_loss: f64) -> f64 {
    if avg_loss != 0.0 {
        let rs = avg_gain / avg_loss;
        return 100.0 - 100.0 / (1.0 + rs);
    }
    if avg_gain == 0.0 {
        50.0
    } else {
        100.0
    }
}
125
126// ── Registry factory ──────────────────────────────────────────────────────────
127
128pub fn factory<S: ::std::hash::BuildHasher>(
129    params: &HashMap<String, String, S>,
130) -> Result<Box<dyn Indicator>, IndicatorError> {
131    let period = param_usize(params, "period", 14)?;
132    let column = match param_str(params, "column", "close") {
133        "open" => PriceColumn::Open,
134        "high" => PriceColumn::High,
135        "low" => PriceColumn::Low,
136        "volume" => PriceColumn::Volume,
137        _ => PriceColumn::Close,
138    };
139    Ok(Box::new(Rsi::new(RsiParams { period, column })))
140}
141
142// ── Tests ─────────────────────────────────────────────────────────────────────
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles whose OHLC are all equal to the given close.
    fn make_candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c,
                low: c,
                close: c,
                volume: 1.0,
            })
            .collect()
    }

    #[test]
    fn rsi_insufficient_data() {
        let err = Rsi::with_period(14)
            .calculate(&make_candles(&[1.0; 10]))
            .unwrap_err();
        assert!(matches!(err, IndicatorError::InsufficientData { .. }));
    }

    #[test]
    fn rsi_leading_nans() {
        let prices: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&prices))
            .unwrap();
        let vals = out.get("RSI_14").unwrap();
        for (i, &v) in vals.iter().enumerate().take(14) {
            assert!(v.is_nan(), "expected NaN at [{i}], got {v}");
        }
        assert!(!vals[14].is_nan());
    }

    #[test]
    fn rsi_constant_gains_is_100() {
        // All deltas positive → avg_loss=0, avg_gain>0 → RSI=100.
        let prices: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&prices))
            .unwrap();
        for &v in out.get("RSI_14").unwrap().iter().filter(|v| !v.is_nan()) {
            assert!((v - 100.0).abs() < 1e-9, "expected 100.0, got {v}");
        }
    }

    #[test]
    fn rsi_constant_losses_is_0() {
        // All deltas negative → avg_gain=0, avg_loss>0 → RSI=0.
        let prices: Vec<f64> = (0..20).map(|i| 100.0 - i as f64).collect();
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&prices))
            .unwrap();
        for &v in out.get("RSI_14").unwrap().iter().filter(|v| !v.is_nan()) {
            assert!(v.abs() < 1e-9, "expected 0.0, got {v}");
        }
    }

    #[test]
    fn rsi_flat_prices_is_50() {
        // Every delta is zero → avg_gain = avg_loss = 0 → neutral 50.
        let out = Rsi::with_period(5)
            .calculate(&make_candles(&[42.0; 12]))
            .unwrap();
        for &v in out.get("RSI_5").unwrap().iter().skip(5) {
            assert!((v - 50.0).abs() < 1e-12, "expected 50.0, got {v}");
        }
    }

    #[test]
    fn rsi_alternating_equal_moves_is_50() {
        // +1, -1, +1, -1 ... with 14 deltas: 7×(+1) and 7×(−1).
        // avg_gain = 7/14 = 0.5, avg_loss = 7/14 = 0.5 → RSI = 50 exactly.
        let mut prices = vec![100.0_f64];
        for i in 0..19 {
            let last = *prices.last().unwrap();
            prices.push(if i % 2 == 0 { last + 1.0 } else { last - 1.0 });
        }
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&prices))
            .unwrap();
        assert!((out.get("RSI_14").unwrap()[14] - 50.0).abs() < 1e-9);
    }

    #[test]
    fn rsi_known_seed_value() {
        // period=3, prices=[10, 11, 9, 11].
        // Deltas: +1, -2, +2.
        // avg_gain=(1+0+2)/3=1.0, avg_loss=(0+2+0)/3=0.667
        // RSI[3] = 100 - 100/(1 + 1.0/(2/3)) = 100 - 100/2.5 = 60.0
        let out = Rsi::with_period(3)
            .calculate(&make_candles(&[10.0, 11.0, 9.0, 11.0]))
            .unwrap();
        assert!((out.get("RSI_3").unwrap()[3] - 60.0).abs() < 1e-6);
    }

    #[test]
    fn rsi_wilder_smoothing_step() {
        // Extend by one bar: prices=[10, 11, 9, 11, 10], delta[4]=-1.
        // After seed: avg_gain=1.0, avg_loss=2/3.
        // Wilder: avg_gain=(1.0*2+0)/3=2/3, avg_loss=(2/3*2+1)/3=7/9
        let out = Rsi::with_period(3)
            .calculate(&make_candles(&[10.0, 11.0, 9.0, 11.0, 10.0]))
            .unwrap();
        let ag = (1.0_f64 * 2.0) / 3.0;
        let al = (2.0_f64 / 3.0 * 2.0 + 1.0) / 3.0;
        let expected = 100.0 - 100.0 / (1.0 + ag / al);
        assert!((out.get("RSI_3").unwrap()[4] - expected).abs() < 1e-9);
    }

    #[test]
    fn rsi_period_one_tracks_last_move() {
        // period=1 → Wilder weight (period-1) is 0, so each value reflects only the
        // latest delta: +1 → 100, -1 → 0, 0 → 50.
        let out = Rsi::with_period(1)
            .calculate(&make_candles(&[1.0, 2.0, 1.0, 1.0]))
            .unwrap();
        let vals = out.get("RSI_1").unwrap();
        assert!(vals[0].is_nan());
        assert!((vals[1] - 100.0).abs() < 1e-12, "got {}", vals[1]);
        assert!(vals[2].abs() < 1e-12, "got {}", vals[2]);
        assert!((vals[3] - 50.0).abs() < 1e-12, "got {}", vals[3]);
    }

    #[test]
    fn rsi_stays_in_range() {
        let prices: Vec<f64> = (0..50)
            .map(|i| 100.0 + (i as f64 * 0.3).sin() * 10.0)
            .collect();
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&prices))
            .unwrap();
        for &v in out.get("RSI_14").unwrap() {
            if !v.is_nan() {
                assert!((0.0..=100.0).contains(&v), "out of range: {v}");
            }
        }
    }

    #[test]
    fn factory_creates_rsi() {
        let ind = factory(&HashMap::new()).unwrap();
        assert_eq!(ind.name(), "RSI");
        assert_eq!(ind.required_len(), 15);
    }

    #[test]
    fn factory_honours_column_param() {
        // Opens rise while closes fall, so RSI on `open` is 100 (on `close` it would be 0).
        let candles: Vec<Candle> = (0..6_i32)
            .map(|i| {
                let x = f64::from(i);
                Candle {
                    time: i64::from(i),
                    open: 10.0 + x,
                    high: 20.0,
                    low: 0.0,
                    close: 10.0 - x,
                    volume: 1.0,
                }
            })
            .collect();
        let mut params = HashMap::new();
        params.insert("period".to_string(), "3".to_string());
        params.insert("column".to_string(), "open".to_string());
        let out = factory(&params).unwrap().calculate(&candles).unwrap();
        let vals = out.get("RSI_3").unwrap();
        assert!(
            vals[3..].iter().all(|v| (v - 100.0).abs() < 1e-12),
            "got {vals:?}"
        );
    }
}