// mantis_ta/indicators/momentum/rsi.rs
use crate::indicators::Indicator;
use crate::types::Candle;

/// Relative Strength Index (RSI) computed incrementally from candle closes.
///
/// Every candle after the first contributes one close-to-close change, which
/// is split into a gain (the positive part) and a loss (the magnitude of the
/// negative part). The first `period` changes are averaged with a simple mean
/// to seed the running averages. Every later change updates them with Wilder's
/// smoothing, `avg = (avg * (period - 1) + x) / period`.
///
/// The first value is produced on the `period + 1`-th candle: one candle to
/// record the initial close, plus `period` changes. When the average loss is
/// zero the value is 100, and this includes a perfectly flat series. When only
/// the average gain is zero the value is 0.
#[derive(Debug, Clone)]
pub struct RSI {
    /// Number of close-to-close changes in the smoothing window (always > 0).
    period: usize,
    /// Close of the most recently seen candle; `None` before the first candle.
    prev_close: Option<f64>,
    /// Sum of gains accumulated during warmup, used to seed `avg_gain`.
    gain_sum: f64,
    /// Sum of losses accumulated during warmup, used to seed `avg_loss`.
    loss_sum: f64,
    /// Smoothed average gain; `Some` once warmup has completed.
    avg_gain: Option<f64>,
    /// Smoothed average loss; `Some` once warmup has completed.
    avg_loss: Option<f64>,
    /// Number of changes accumulated during warmup (stops at `period`).
    count: usize,
}
41impl RSI {
42 pub fn new(period: usize) -> Self {
43 assert!(period > 0, "period must be > 0");
44 Self {
45 period,
46 prev_close: None,
47 gain_sum: 0.0,
48 loss_sum: 0.0,
49 avg_gain: None,
50 avg_loss: None,
51 count: 0,
52 }
53 }
54
55 #[inline]
56 fn compute_rsi(&self, avg_gain: f64, avg_loss: f64) -> f64 {
57 if avg_loss == 0.0 {
58 return 100.0;
59 }
60 if avg_gain == 0.0 {
61 return 0.0;
62 }
63 let rs = avg_gain / avg_loss;
64 100.0 - (100.0 / (1.0 + rs))
65 }
66}
67
68impl Indicator for RSI {
69 type Output = f64;
70
71 fn next(&mut self, candle: &Candle) -> Option<Self::Output> {
72 let close = candle.close;
73
74 let Some(prev) = self.prev_close else {
75 self.prev_close = Some(close);
76 return None;
77 };
78
79 let change = close - prev;
80 let gain = change.max(0.0);
81 let loss = (-change).max(0.0);
82
83 if let (Some(prev_avg_gain), Some(prev_avg_loss)) = (self.avg_gain, self.avg_loss) {
84 let new_avg_gain =
86 (prev_avg_gain * (self.period as f64 - 1.0) + gain) / self.period as f64;
87 let new_avg_loss =
88 (prev_avg_loss * (self.period as f64 - 1.0) + loss) / self.period as f64;
89
90 self.avg_gain = Some(new_avg_gain);
91 self.avg_loss = Some(new_avg_loss);
92 self.prev_close = Some(close);
93 return Some(self.compute_rsi(new_avg_gain, new_avg_loss));
94 }
95
96 self.gain_sum += gain;
98 self.loss_sum += loss;
99 self.count += 1;
100
101 if self.count == self.period {
102 let avg_gain = self.gain_sum / self.period as f64;
103 let avg_loss = self.loss_sum / self.period as f64;
104 self.avg_gain = Some(avg_gain);
105 self.avg_loss = Some(avg_loss);
106 self.prev_close = Some(close);
107 return Some(self.compute_rsi(avg_gain, avg_loss));
108 }
109
110 self.prev_close = Some(close);
111 None
112 }
113
114 fn reset(&mut self) {
115 self.prev_close = None;
116 self.gain_sum = 0.0;
117 self.loss_sum = 0.0;
118 self.avg_gain = None;
119 self.avg_loss = None;
120 self.count = 0;
121 }
122
123 fn warmup_period(&self) -> usize {
124 self.period + 1
125 }
126
127 fn clone_boxed(&self) -> Box<dyn Indicator<Output = Self::Output>> {
128 Box::new(self.clone())
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candle whose OHLC fields all equal `price`.
    fn candle(price: f64) -> Candle {
        Candle {
            timestamp: 0,
            open: price,
            high: price,
            low: price,
            close: price,
            volume: 0.0,
        }
    }

    /// Feeds `prices` through `rsi` in order and collects every output.
    fn feed(rsi: &mut RSI, prices: &[f64]) -> Vec<Option<f64>> {
        prices.iter().map(|&p| rsi.next(&candle(p))).collect()
    }

    /// Asserts that `actual` is `Some` and within 1e-9 of `expected`.
    fn assert_close(actual: Option<f64>, expected: f64) {
        let actual = actual.expect("expected an RSI value");
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn rsi_emits_after_warmup() {
        let mut rsi = RSI::new(3);
        let outputs = feed(&mut rsi, &[1.0, 2.0, 3.0, 2.5, 2.0]);
        let wp = rsi.warmup_period();

        assert!(outputs.iter().take(wp - 1).all(|o| o.is_none()));
        assert!(outputs[wp - 1].is_some());
    }

    #[test]
    fn rsi_matches_hand_computed_values() {
        // Changes with period 3: +1, +1, -0.5, -0.5.
        // Seed:        avg_gain = 2/3, avg_loss = 0.5/3  -> RS = 4   -> RSI = 80.
        // Wilder step: avg_gain = 4/9, avg_loss = 5/18   -> RS = 1.6 -> RSI = 100 - 100/2.6.
        let mut rsi = RSI::new(3);
        let outputs = feed(&mut rsi, &[1.0, 2.0, 3.0, 2.5, 2.0]);
        assert_close(outputs[3], 80.0);
        assert_close(outputs[4], 100.0 - 100.0 / 2.6);
    }

    #[test]
    fn one_sided_moves_saturate() {
        let mut rising = RSI::new(3);
        let up = feed(&mut rising, &[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_close(up[3], 100.0);
        assert_close(up[4], 100.0);

        let mut falling = RSI::new(3);
        let down = feed(&mut falling, &[5.0, 4.0, 3.0, 2.0, 1.0]);
        assert_close(down[3], 0.0);
        assert_close(down[4], 0.0);
    }

    #[test]
    fn reset_restarts_from_scratch() {
        let prices = [1.0, 2.0, 3.0, 2.5, 2.0];
        let mut rsi = RSI::new(3);
        let first = feed(&mut rsi, &prices);
        rsi.reset();
        let second = feed(&mut rsi, &prices);
        assert_eq!(first, second);
    }

    #[test]
    fn boxed_clone_carries_state() {
        let mut rsi = RSI::new(3);
        feed(&mut rsi, &[1.0, 2.0, 3.0, 2.5]);
        let mut boxed = rsi.clone_boxed();
        let next = candle(2.0);
        assert_eq!(boxed.next(&next), rsi.next(&next));
    }

    #[test]
    #[should_panic(expected = "period must be > 0")]
    fn zero_period_panics() {
        let _ = RSI::new(0);
    }
}