indicators/momentum/
rsi.rs1use std::collections::HashMap;
18
19use crate::error::IndicatorError;
20use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
21use crate::registry::{param_str, param_usize};
22use crate::types::Candle;
23
24#[derive(Debug, Clone)]
27pub struct RsiParams {
28 pub period: usize,
30 pub column: PriceColumn,
32}
33
34impl Default for RsiParams {
35 fn default() -> Self {
36 Self {
37 period: 14,
38 column: PriceColumn::Close,
39 }
40 }
41}
42
43#[derive(Debug, Clone)]
46pub struct Rsi {
47 pub params: RsiParams,
48}
49
50impl Rsi {
51 pub fn new(params: RsiParams) -> Self {
52 Self { params }
53 }
54 pub fn with_period(period: usize) -> Self {
55 Self::new(RsiParams {
56 period,
57 ..Default::default()
58 })
59 }
60 fn output_key(&self) -> String {
61 format!("RSI_{}", self.params.period)
62 }
63}
64
65impl Indicator for Rsi {
66 fn name(&self) -> &str {
67 "RSI"
68 }
69
70 fn required_len(&self) -> usize {
72 self.params.period + 1
73 }
74
75 fn required_columns(&self) -> &[&'static str] {
76 &["close"]
77 }
78
79 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
80 self.check_len(candles)?;
81
82 let prices = self.params.column.extract(candles);
83 let n = prices.len();
84 let p = self.params.period;
85 let mut values = vec![f64::NAN; n];
86
87 let mut avg_gain = 0.0_f64;
89 let mut avg_loss = 0.0_f64;
90 for i in 1..=p {
91 let delta = prices[i] - prices[i - 1];
92 if delta > 0.0 {
93 avg_gain += delta;
94 } else {
95 avg_loss += -delta;
96 }
97 }
98 avg_gain /= p as f64;
99 avg_loss /= p as f64;
100 values[p] = rsi_from(avg_gain, avg_loss);
101
102 let w = (p - 1) as f64;
104 for i in (p + 1)..n {
105 let delta = prices[i] - prices[i - 1];
106 let gain = if delta > 0.0 { delta } else { 0.0 };
107 let loss = if delta < 0.0 { -delta } else { 0.0 };
108 avg_gain = (avg_gain * w + gain) / p as f64;
109 avg_loss = (avg_loss * w + loss) / p as f64;
110 values[i] = rsi_from(avg_gain, avg_loss);
111 }
112
113 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
114 }
115}
116
/// Maps smoothed average gain and loss onto the 0–100 RSI scale.
///
/// With zero average loss the gain/loss ratio is undefined. A window with no
/// gains either maps to 50.0; otherwise the result is 100.0.
#[inline]
fn rsi_from(avg_gain: f64, avg_loss: f64) -> f64 {
    if avg_loss != 0.0 {
        let rs = avg_gain / avg_loss;
        return 100.0 - 100.0 / (1.0 + rs);
    }
    if avg_gain == 0.0 {
        50.0
    } else {
        100.0
    }
}
125
126pub fn factory(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
129 let period = param_usize(params, "period", 14)?;
130 let column = match param_str(params, "column", "close") {
131 "open" => PriceColumn::Open,
132 "high" => PriceColumn::High,
133 "low" => PriceColumn::Low,
134 "volume" => PriceColumn::Volume,
135 _ => PriceColumn::Close,
136 };
137 Ok(Box::new(Rsi::new(RsiParams { period, column })))
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Flat candles whose open/high/low/close all equal the given price;
    /// `time` is the index and volume is 1.0.
    fn make_candles(closes: &[f64]) -> Vec<Candle> {
        let mut candles = Vec::with_capacity(closes.len());
        for (idx, &price) in closes.iter().enumerate() {
            candles.push(Candle {
                time: idx as i64,
                open: price,
                high: price,
                low: price,
                close: price,
                volume: 1.0,
            });
        }
        candles
    }

    /// Strictly rising series 0.0, 1.0, ..., (len - 1) as f64.
    fn ramp(len: usize) -> Vec<f64> {
        (0..len).map(|i| i as f64).collect()
    }

    #[test]
    fn rsi_insufficient_data() {
        let candles = make_candles(&[1.0; 10]);
        let result = Rsi::with_period(14).calculate(&candles);
        assert!(matches!(
            result,
            Err(IndicatorError::InsufficientData { .. })
        ));
    }

    #[test]
    fn rsi_leading_nans() {
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&ramp(20)))
            .unwrap();
        let vals = out.get("RSI_14").unwrap();
        for (i, v) in vals.iter().take(14).enumerate() {
            assert!(v.is_nan(), "expected NaN at [{i}], got {v}");
        }
        assert!(!vals[14].is_nan());
    }

    #[test]
    fn rsi_constant_gains_is_100() {
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&ramp(20)))
            .unwrap();
        let defined = out
            .get("RSI_14")
            .unwrap()
            .iter()
            .copied()
            .filter(|v| !v.is_nan());
        for v in defined {
            assert!((v - 100.0).abs() < 1e-9, "expected 100.0, got {v}");
        }
    }

    #[test]
    fn rsi_constant_losses_is_0() {
        let falling: Vec<f64> = ramp(20).into_iter().map(|x| 100.0 - x).collect();
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&falling))
            .unwrap();
        let defined = out
            .get("RSI_14")
            .unwrap()
            .iter()
            .copied()
            .filter(|v| !v.is_nan());
        for v in defined {
            assert!(v.abs() < 1e-9, "expected 0.0, got {v}");
        }
    }

    #[test]
    fn rsi_alternating_equal_moves_is_50() {
        // 100, 101, 100, 101, ...: the first 14 changes are seven +1 and seven -1.
        let zigzag: Vec<f64> = (0..20)
            .map(|k| if k % 2 == 0 { 100.0 } else { 101.0 })
            .collect();
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&zigzag))
            .unwrap();
        let seed = out.get("RSI_14").unwrap()[14];
        assert!((seed - 50.0).abs() < 1e-9);
    }

    #[test]
    fn rsi_known_seed_value() {
        // Changes +1, -2, +2 -> avg gain 1, avg loss 2/3 -> RS 1.5 -> RSI 60.
        let candles = make_candles(&[10.0, 11.0, 9.0, 11.0]);
        let out = Rsi::with_period(3).calculate(&candles).unwrap();
        let seed = out.get("RSI_3").unwrap()[3];
        assert!((seed - 60.0).abs() < 1e-6);
    }

    #[test]
    fn rsi_wilder_smoothing_step() {
        // The seed is as above; the next change is -1.
        let candles = make_candles(&[10.0, 11.0, 9.0, 11.0, 10.0]);
        let out = Rsi::with_period(3).calculate(&candles).unwrap();
        let ag = (1.0_f64 * 2.0) / 3.0;
        let al = (2.0_f64 / 3.0 * 2.0 + 1.0) / 3.0;
        let expected = 100.0 - 100.0 / (1.0 + ag / al);
        let smoothed = out.get("RSI_3").unwrap()[4];
        assert!((smoothed - expected).abs() < 1e-9);
    }

    #[test]
    fn rsi_stays_in_range() {
        let wave: Vec<f64> = (0..50)
            .map(|i| 100.0 + (i as f64 * 0.3).sin() * 10.0)
            .collect();
        let out = Rsi::with_period(14)
            .calculate(&make_candles(&wave))
            .unwrap();
        for &v in out.get("RSI_14").unwrap().iter().filter(|v| !v.is_nan()) {
            assert!((0.0..=100.0).contains(&v), "out of range: {v}");
        }
    }

    #[test]
    fn factory_creates_rsi() {
        let built = factory(&HashMap::new()).unwrap();
        assert_eq!(built.name(), "RSI");
        assert_eq!(built.required_len(), 15);
    }
}