Skip to main content

mantis_ta/indicators/momentum/
rsi.rs

1use crate::indicators::Indicator;
2use crate::types::Candle;
3
/// Relative Strength Index (RSI) using Wilder smoothing.
///
/// The first `period` price changes are averaged with a simple mean to seed the
/// average gain and average loss. Every later change updates both averages with
/// Wilder's recursive smoothing:
///
/// ```text
/// avg = (prev_avg * (period - 1) + current) / period
/// RSI = 100 - 100 / (1 + avg_gain / avg_loss)
/// ```
///
/// A zero average loss yields `100.0` (this includes a perfectly flat series);
/// otherwise a zero average gain yields `0.0`.
///
/// The first candle only provides a reference close, so the first value is
/// produced on candle `period + 1` (index `period`), matching `warmup_period()`.
///
/// # Examples
/// ```rust
/// use mantis_ta::indicators::{Indicator, RSI};
/// use mantis_ta::types::Candle;
///
/// let prices = [1.0, 2.0, 3.0, 2.5, 2.0, 2.2];
/// let candles: Vec<Candle> = prices
///     .iter()
///     .enumerate()
///     .map(|(i, p)| Candle {
///         timestamp: i as i64,
///         open: *p,
///         high: *p,
///         low: *p,
///         close: *p,
///         volume: 0.0,
///     })
///     .collect();
///
/// let out = RSI::new(3).calculate(&candles);
/// // Warmup: the first `period` bars yield `None`; the first value is at index `period`.
/// assert!(out.iter().take(3).all(|v| v.is_none()));
/// assert!(out[3].is_some());
/// ```
#[derive(Debug, Clone)]
pub struct RSI {
    /// Smoothing length: the number of price changes averaged.
    period: usize,
    /// Close of the previous candle; `None` until the first candle is seen.
    prev_close: Option<f64>,
    /// Running sum of gains while warming up.
    gain_sum: f64,
    /// Running sum of losses (stored as positive numbers) while warming up.
    loss_sum: f64,
    /// Smoothed average gain; `Some` once warmup has completed.
    avg_gain: Option<f64>,
    /// Smoothed average loss; `Some` once warmup has completed.
    avg_loss: Option<f64>,
    /// Number of price changes accumulated so far during warmup.
    count: usize,
}

41impl RSI {
42    pub fn new(period: usize) -> Self {
43        assert!(period > 0, "period must be > 0");
44        Self {
45            period,
46            prev_close: None,
47            gain_sum: 0.0,
48            loss_sum: 0.0,
49            avg_gain: None,
50            avg_loss: None,
51            count: 0,
52        }
53    }
54
55    #[inline]
56    fn compute_rsi(&self, avg_gain: f64, avg_loss: f64) -> f64 {
57        if avg_loss == 0.0 {
58            return 100.0;
59        }
60        if avg_gain == 0.0 {
61            return 0.0;
62        }
63        let rs = avg_gain / avg_loss;
64        100.0 - (100.0 / (1.0 + rs))
65    }
66}
67
68impl Indicator for RSI {
69    type Output = f64;
70
71    fn next(&mut self, candle: &Candle) -> Option<Self::Output> {
72        let close = candle.close;
73
74        let Some(prev) = self.prev_close else {
75            self.prev_close = Some(close);
76            return None;
77        };
78
79        let change = close - prev;
80        let gain = change.max(0.0);
81        let loss = (-change).max(0.0);
82
83        if let (Some(prev_avg_gain), Some(prev_avg_loss)) = (self.avg_gain, self.avg_loss) {
84            // Wilder smoothing
85            let new_avg_gain =
86                (prev_avg_gain * (self.period as f64 - 1.0) + gain) / self.period as f64;
87            let new_avg_loss =
88                (prev_avg_loss * (self.period as f64 - 1.0) + loss) / self.period as f64;
89
90            self.avg_gain = Some(new_avg_gain);
91            self.avg_loss = Some(new_avg_loss);
92            self.prev_close = Some(close);
93            return Some(self.compute_rsi(new_avg_gain, new_avg_loss));
94        }
95
96        // Warmup accumulation
97        self.gain_sum += gain;
98        self.loss_sum += loss;
99        self.count += 1;
100
101        if self.count == self.period {
102            let avg_gain = self.gain_sum / self.period as f64;
103            let avg_loss = self.loss_sum / self.period as f64;
104            self.avg_gain = Some(avg_gain);
105            self.avg_loss = Some(avg_loss);
106            self.prev_close = Some(close);
107            return Some(self.compute_rsi(avg_gain, avg_loss));
108        }
109
110        self.prev_close = Some(close);
111        None
112    }
113
114    fn reset(&mut self) {
115        self.prev_close = None;
116        self.gain_sum = 0.0;
117        self.loss_sum = 0.0;
118        self.avg_gain = None;
119        self.avg_loss = None;
120        self.count = 0;
121    }
122
123    fn warmup_period(&self) -> usize {
124        self.period + 1
125    }
126
127    fn clone_boxed(&self) -> Box<dyn Indicator<Output = Self::Output>> {
128        Box::new(self.clone())
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles whose open/high/low/close all equal the given price.
    fn candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, &p)| Candle {
                timestamp: i as i64,
                open: p,
                high: p,
                low: p,
                close: p,
                volume: 0.0,
            })
            .collect()
    }

    /// Feeds every price through `rsi` and collects the outputs.
    fn run(rsi: &mut RSI, prices: &[f64]) -> Vec<Option<f64>> {
        candles(prices).iter().map(|c| rsi.next(c)).collect()
    }

    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn rsi_emits_after_warmup() {
        let mut rsi = RSI::new(3);
        let outputs = run(&mut rsi, &[1.0, 2.0, 3.0, 2.5, 2.0]);
        let wp = rsi.warmup_period(); // period + 1
        // The first wp - 1 outputs are None ...
        assert!(outputs.iter().take(wp - 1).all(|o| o.is_none()));
        // ... and the output at index wp - 1 is the first Some.
        assert!(outputs[wp - 1].is_some());
    }

    #[test]
    fn rsi_matches_hand_computed_wilder_values() {
        // Changes with period 3: +1, +1, -0.5, -0.5.
        // Seed: avg_gain = 2/3, avg_loss = 1/6 -> RS = 4   -> RSI = 80.
        // Next: avg_gain = 4/9, avg_loss = 5/18 -> RS = 1.6 -> RSI = 800/13.
        let mut rsi = RSI::new(3);
        let out = run(&mut rsi, &[1.0, 2.0, 3.0, 2.5, 2.0]);
        assert_close(out[3].unwrap(), 80.0);
        assert_close(out[4].unwrap(), 800.0 / 13.0);
    }

    #[test]
    fn rsi_saturates_on_one_directional_moves() {
        let rising = run(&mut RSI::new(2), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(rising, vec![None, None, Some(100.0), Some(100.0)]);

        let falling = run(&mut RSI::new(2), &[4.0, 3.0, 2.0, 1.0]);
        assert_eq!(falling, vec![None, None, Some(0.0), Some(0.0)]);
    }

    #[test]
    fn flat_series_reports_100() {
        // Zero average loss takes precedence over zero average gain.
        let flat = run(&mut RSI::new(2), &[5.0, 5.0, 5.0, 5.0]);
        assert_eq!(flat, vec![None, None, Some(100.0), Some(100.0)]);
    }

    #[test]
    fn period_one_tracks_last_change() {
        let out = run(&mut RSI::new(1), &[1.0, 2.0, 1.5]);
        assert_eq!(out, vec![None, Some(100.0), Some(0.0)]);
    }

    #[test]
    fn reset_restarts_warmup_and_reproduces_output() {
        let prices = [1.0, 2.0, 3.0, 2.5, 2.0, 2.2];
        let mut rsi = RSI::new(3);
        let first = run(&mut rsi, &prices);
        rsi.reset();
        let second = run(&mut rsi, &prices);
        assert_eq!(first, second);
    }

    #[test]
    fn outputs_stay_within_bounds() {
        let prices = [10.0, 11.5, 10.2, 12.8, 12.1, 9.7, 10.4, 13.3, 12.9, 11.0];
        let out = run(&mut RSI::new(4), &prices);
        assert!(out.iter().flatten().all(|v| (0.0..=100.0).contains(v)));
    }
}