1use quant_primitives::Candle;
4use rust_decimal::Decimal;
5
6use crate::error::IndicatorError;
7use crate::indicator::Indicator;
8use crate::series::Series;
9
/// Relative Strength Index indicator configuration.
///
/// Holds the look-back length together with a pre-rendered display label so
/// that the label can be handed out by reference without re-formatting.
#[derive(Clone, Debug)]
pub struct Rsi {
    /// Number of price changes averaged to seed and smooth the indicator.
    period: usize,
    /// Human-readable label of the form `RSI(<period>)`.
    name: String,
}
44
45impl Rsi {
46 pub fn new(period: usize) -> Result<Self, IndicatorError> {
54 if period == 0 {
55 return Err(IndicatorError::InvalidParameter {
56 message: "RSI period must be > 0".to_string(),
57 });
58 }
59 Ok(Self {
60 period,
61 name: format!("RSI({})", period),
62 })
63 }
64}
65
66impl Indicator for Rsi {
67 fn name(&self) -> &str {
68 &self.name
69 }
70
71 fn warmup_period(&self) -> usize {
72 self.period + 1
75 }
76
77 fn compute(&self, candles: &[Candle]) -> Result<Series, IndicatorError> {
78 let required = self.period + 1;
79 if candles.len() < required {
80 return Err(IndicatorError::InsufficientData {
81 required,
82 actual: candles.len(),
83 });
84 }
85
86 let changes: Vec<Decimal> = candles
88 .windows(2)
89 .map(|w| w[1].close() - w[0].close())
90 .collect();
91
92 let mut values = Vec::with_capacity(candles.len() - required + 1);
93 let period_dec = Decimal::from(self.period as u64);
94
95 let mut avg_gain = Decimal::ZERO;
97 let mut avg_loss = Decimal::ZERO;
98
99 for change in changes.iter().take(self.period) {
100 if *change > Decimal::ZERO {
101 avg_gain += *change;
102 } else {
103 avg_loss += change.abs();
104 }
105 }
106 avg_gain /= period_dec;
107 avg_loss /= period_dec;
108
109 let rsi = calculate_rsi(avg_gain, avg_loss);
111 let ts = candles[self.period].timestamp();
112 values.push((ts, rsi));
113
114 for (i, change) in changes.iter().enumerate().skip(self.period) {
116 let (gain, loss) = if *change > Decimal::ZERO {
117 (*change, Decimal::ZERO)
118 } else {
119 (Decimal::ZERO, change.abs())
120 };
121
122 avg_gain = (avg_gain * (period_dec - Decimal::ONE) + gain) / period_dec;
124 avg_loss = (avg_loss * (period_dec - Decimal::ONE) + loss) / period_dec;
125
126 let rsi = calculate_rsi(avg_gain, avg_loss);
127 let ts = candles[i + 1].timestamp();
128 values.push((ts, rsi));
129 }
130
131 Ok(Series::new(values))
132 }
133}
134
135fn calculate_rsi(avg_gain: Decimal, avg_loss: Decimal) -> Decimal {
137 if avg_loss == Decimal::ZERO {
138 if avg_gain == Decimal::ZERO {
139 Decimal::from(50)
141 } else {
142 Decimal::from(100)
144 }
145 } else {
146 let rs = avg_gain / avg_loss;
147 Decimal::from(100) - (Decimal::from(100) / (Decimal::ONE + rs))
148 }
149}
150
151#[cfg(test)]
152#[path = "rsi_tests.rs"]
153mod tests;