Skip to main content

quantwave_core/indicators/incremental/
rsi.rs

//! Native streaming RSI with O(1) work per update, using Wilder smoothing for parity with TA-Lib.
2
3use crate::traits::Next;
4
/// Converts Wilder average gain/loss into an RSI value in `[0, 100]`.
///
/// Evaluated as `100 * (avg_gain / (avg_gain + avg_loss))`, the same expression
/// TA-Lib's `ta_RSI.c` uses, instead of the algebraically equivalent
/// `100 - 100 / (1 + avg_gain / avg_loss)`, so floating-point rounding follows
/// the reference implementation.
///
/// Edge cases:
/// * no losses, some gains → `100.0`;
/// * no gains and no losses (a completely flat window) → `0.0`, which is what
///   TA-Lib emits (it does not report such a window as "overbought" `100.0`);
/// * `NaN` in either average propagates to the result.
///
/// NOTE(review): TA-Lib tests the sum with a ±1e-8 tolerance (`TA_IS_ZERO`);
/// this uses an exact zero check. Confirm against `talib_rs` if average moves
/// smaller than 1e-8 matter for the instruments being processed.
#[inline]
fn rsi_from_avgs(avg_gain: f64, avg_loss: f64) -> f64 {
    let total = avg_gain + avg_loss;
    if total == 0.0 {
        0.0
    } else {
        100.0 * (avg_gain / total)
    }
}

/// Relative Strength Index (RSI) — Wilder smoothing, matches `talib_rs::momentum::rsi`.
///
/// Streaming indicator: feed one close per call to `Next::next`. The first
/// `timeperiod` calls return `NaN` (one call to record the initial close, then
/// `timeperiod - 1` accumulated changes); every later call returns a value.
/// A `timeperiod` below 2 makes every call return `NaN`.
#[derive(Debug, Clone)]
// The indicator is conventionally written as the acronym `RSI`; keep that
// public name rather than `Rsi`. The lint that flags all-caps type names is
// Clippy's `upper_case_acronyms` (rustc's `non_camel_case_types` accepts `RSI`).
#[allow(clippy::upper_case_acronyms)]
pub struct RSI {
    /// Smoothing period. Changing it after construction is not supported:
    /// the cached `period_f` copy is only set by the constructor.
    pub timeperiod: usize,
    /// `timeperiod` as `f64`, cached so updates avoid the conversion.
    period_f: f64,
    /// Most recent close, `None` until the first input arrives.
    prev_close: Option<f64>,
    /// Current Wilder-smoothed average gain (meaningful after warm-up).
    avg_gain: f64,
    /// Current Wilder-smoothed average loss (meaningful after warm-up).
    avg_loss: f64,
    /// Number of price changes accumulated during warm-up; stops at `timeperiod`.
    warmup_changes: usize,
    /// Sum of gains over the warm-up changes, used to seed `avg_gain`.
    sum_gain: f64,
    /// Sum of losses over the warm-up changes, used to seed `avg_loss`.
    sum_loss: f64,
}

29impl RSI {
30    pub fn new(timeperiod: usize) -> Self {
31        Self {
32            timeperiod,
33            period_f: timeperiod as f64,
34            prev_close: None,
35            avg_gain: 0.0,
36            avg_loss: 0.0,
37            warmup_changes: 0,
38            sum_gain: 0.0,
39            sum_loss: 0.0,
40        }
41    }
42}
43
44impl Next<f64> for RSI {
45    type Output = f64;
46
47    fn next(&mut self, input: f64) -> Self::Output {
48        let period = self.timeperiod;
49        if period < 2 {
50            return f64::NAN;
51        }
52
53        let Some(prev) = self.prev_close else {
54            self.prev_close = Some(input);
55            return f64::NAN;
56        };
57
58        let change = input - prev;
59        self.prev_close = Some(input);
60
61        let (gain, loss) = if change > 0.0 {
62            (change, 0.0)
63        } else {
64            (0.0, -change)
65        };
66
67        if self.warmup_changes < period {
68            self.warmup_changes += 1;
69            self.sum_gain += gain;
70            self.sum_loss += loss;
71            if self.warmup_changes < period {
72                return f64::NAN;
73            }
74            self.avg_gain = self.sum_gain / self.period_f;
75            self.avg_loss = self.sum_loss / self.period_f;
76            return rsi_from_avgs(self.avg_gain, self.avg_loss);
77        }
78
79        self.avg_gain =
80            (self.avg_gain * (self.period_f - 1.0) + gain) / self.period_f;
81        self.avg_loss =
82            (self.avg_loss * (self.period_f - 1.0) + loss) / self.period_f;
83        rsi_from_avgs(self.avg_gain, self.avg_loss)
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Streaming output must agree element-by-element with the batch TA-Lib
        // port for random positive price series and every period in 2..=30
        // (TA-Lib rejects periods below 2). Failures are reported through
        // `prop_assert!` so proptest can shrink them and print the failing index.
        #[test]
        fn test_rsi_parity(input in prop::collection::vec(0.1..100.0, 1..100), period in 2usize..=30) {
            let mut rsi = RSI::new(period);
            let streaming_results: Vec<f64> = input.iter().map(|&x| rsi.next(x)).collect();
            // A batch error (e.g. an invalid period) is treated as "all NaN",
            // which is what the streaming side produces in that case.
            let batch_results = talib_rs::momentum::rsi(&input, period)
                .unwrap_or_else(|_| vec![f64::NAN; input.len()]);

            for (i, (&s, &b)) in streaming_results.iter().zip(batch_results.iter()).enumerate() {
                if s.is_nan() {
                    prop_assert!(b.is_nan(), "index {}: streaming NaN, batch {}", i, b);
                } else {
                    prop_assert!(
                        approx::relative_eq!(s, b, epsilon = 1e-6),
                        "index {}: streaming {} vs batch {}",
                        i,
                        s,
                        b
                    );
                }
            }
        }
    }
}