indicators/momentum/
rsi.rs1use std::collections::HashMap;
18
19use crate::error::IndicatorError;
20use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
21use crate::registry::{param_str, param_usize};
22use crate::types::Candle;
23
24#[derive(Debug, Clone)]
27pub struct RsiParams {
28 pub period: usize,
30 pub column: PriceColumn,
32}
33
34impl Default for RsiParams {
35 fn default() -> Self {
36 Self {
37 period: 14,
38 column: PriceColumn::Close,
39 }
40 }
41}
42
43#[derive(Debug, Clone)]
46pub struct Rsi {
47 pub params: RsiParams,
48}
49
50impl Rsi {
51 pub fn new(params: RsiParams) -> Self {
52 Self { params }
53 }
54 pub fn with_period(period: usize) -> Self {
55 Self::new(RsiParams {
56 period,
57 ..Default::default()
58 })
59 }
60 fn output_key(&self) -> String {
61 format!("RSI_{}", self.params.period)
62 }
63}
64
65impl Indicator for Rsi {
66 fn name(&self) -> &'static str {
67 "RSI"
68 }
69
70 fn required_len(&self) -> usize {
72 self.params.period + 1
73 }
74
75 fn required_columns(&self) -> &[&'static str] {
76 &["close"]
77 }
78
79 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
80 self.check_len(candles)?;
81
82 let prices = self.params.column.extract(candles);
83 let n = prices.len();
84 let p = self.params.period;
85 let mut values = vec![f64::NAN; n];
86
87 let mut avg_gain = 0.0_f64;
89 let mut avg_loss = 0.0_f64;
90 for i in 1..=p {
91 let delta = prices[i] - prices[i - 1];
92 if delta > 0.0 {
93 avg_gain += delta;
94 } else {
95 avg_loss += -delta;
96 }
97 }
98 avg_gain /= p as f64;
99 avg_loss /= p as f64;
100 values[p] = rsi_from(avg_gain, avg_loss);
101
102 let w = (p - 1) as f64;
104 for i in (p + 1)..n {
105 let delta = prices[i] - prices[i - 1];
106 let gain = if delta > 0.0 { delta } else { 0.0 };
107 let loss = if delta < 0.0 { -delta } else { 0.0 };
108 avg_gain = (avg_gain * w + gain) / p as f64;
109 avg_loss = (avg_loss * w + loss) / p as f64;
110 values[i] = rsi_from(avg_gain, avg_loss);
111 }
112
113 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
114 }
115}
116
/// Converts an average gain and average loss into an RSI value on the 0–100 scale.
///
/// A zero average loss is special-cased to avoid dividing by zero. The result is
/// `100.0` when there were gains, and the neutral `50.0` when both averages are zero.
#[inline]
fn rsi_from(avg_gain: f64, avg_loss: f64) -> f64 {
    if avg_loss != 0.0 {
        let relative_strength = avg_gain / avg_loss;
        return 100.0 - 100.0 / (1.0 + relative_strength);
    }
    if avg_gain == 0.0 {
        50.0
    } else {
        100.0
    }
}
125
126pub fn factory<S: ::std::hash::BuildHasher>(
129 params: &HashMap<String, String, S>,
130) -> Result<Box<dyn Indicator>, IndicatorError> {
131 let period = param_usize(params, "period", 14)?;
132 let column = match param_str(params, "column", "close") {
133 "open" => PriceColumn::Open,
134 "high" => PriceColumn::High,
135 "low" => PriceColumn::Low,
136 "volume" => PriceColumn::Volume,
137 _ => PriceColumn::Close,
138 };
139 Ok(Box::new(Rsi::new(RsiParams { period, column })))
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one flat candle per closing price (open = high = low = close,
    /// volume 1), timestamped by its index.
    fn make_candles(closes: &[f64]) -> Vec<Candle> {
        let mut candles = Vec::with_capacity(closes.len());
        for (idx, &price) in closes.iter().enumerate() {
            candles.push(Candle {
                time: i64::try_from(idx).expect("time index fits i64"),
                open: price,
                high: price,
                low: price,
                close: price,
                volume: 1.0,
            });
        }
        candles
    }

    /// Runs an RSI of `period` over `closes` and returns the series stored under `key`.
    fn rsi_series(period: usize, closes: &[f64], key: &str) -> Vec<f64> {
        let output = Rsi::with_period(period)
            .calculate(&make_candles(closes))
            .unwrap();
        output.get(key).unwrap().to_vec()
    }

    /// `0.0, 1.0, ..., 19.0`: a strictly rising series.
    fn rising_20() -> Vec<f64> {
        (0_u8..20).map(f64::from).collect()
    }

    #[test]
    fn rsi_insufficient_data() {
        let err = Rsi::with_period(14)
            .calculate(&make_candles(&[1.0; 10]))
            .expect_err("10 candles cannot satisfy a 14-period RSI");
        assert!(matches!(err, IndicatorError::InsufficientData { .. }));
    }

    #[test]
    fn rsi_leading_nans() {
        let series = rsi_series(14, &rising_20(), "RSI_14");
        let (warmup, rest) = series.split_at(14);
        for (idx, &value) in warmup.iter().enumerate() {
            assert!(value.is_nan(), "expected NaN at [{idx}], got {value}");
        }
        assert!(!rest[0].is_nan());
    }

    #[test]
    fn rsi_constant_gains_is_100() {
        let series = rsi_series(14, &rising_20(), "RSI_14");
        for value in series.into_iter().filter(|v| !v.is_nan()) {
            assert!((value - 100.0).abs() < 1e-9, "expected 100.0, got {value}");
        }
    }

    #[test]
    fn rsi_constant_losses_is_0() {
        let falling: Vec<f64> = (0_u8..20).map(|i| 100.0 - f64::from(i)).collect();
        let series = rsi_series(14, &falling, "RSI_14");
        for value in series.into_iter().filter(|v| !v.is_nan()) {
            assert!(value.abs() < 1e-9, "expected 0.0, got {value}");
        }
    }

    #[test]
    fn rsi_alternating_equal_moves_is_50() {
        // 100, 101, 100, 101, ...: equal-sized up and down moves.
        let zigzag: Vec<f64> = (0_u8..20)
            .map(|i| if i % 2 == 0 { 100.0 } else { 101.0 })
            .collect();
        let series = rsi_series(14, &zigzag, "RSI_14");
        assert!((series[14] - 50.0).abs() < 1e-9);
    }

    #[test]
    fn rsi_known_seed_value() {
        // Changes +1, -2, +2: avg gain 3/3 = 1, avg loss 2/3, RS = 1.5, so RSI = 60.
        let series = rsi_series(3, &[10.0, 11.0, 9.0, 11.0], "RSI_3");
        assert!((series[3] - 60.0).abs() < 1e-6);
    }

    #[test]
    fn rsi_wilder_smoothing_step() {
        // Seed as above, then one -1 change smoothed with weight (3 - 1) / 3.
        let series = rsi_series(3, &[10.0, 11.0, 9.0, 11.0, 10.0], "RSI_3");
        let avg_gain = (1.0_f64 * 2.0) / 3.0;
        let avg_loss = (2.0_f64 / 3.0 * 2.0 + 1.0) / 3.0;
        let expected = 100.0 - 100.0 / (1.0 + avg_gain / avg_loss);
        assert!((series[4] - expected).abs() < 1e-9);
    }

    #[test]
    fn rsi_stays_in_range() {
        let wave: Vec<f64> = (0_u8..50)
            .map(|i| 100.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let series = rsi_series(14, &wave, "RSI_14");
        for &value in series.iter().filter(|v| !v.is_nan()) {
            assert!((0.0..=100.0).contains(&value), "out of range: {value}");
        }
    }

    #[test]
    fn factory_creates_rsi() {
        let indicator = factory(&HashMap::new()).unwrap();
        assert_eq!(indicator.name(), "RSI");
        assert_eq!(indicator.required_len(), 15);
    }
}