use super::{IndicatorError, Result};
/// Computes the Relative Strength Index for every index from `period` onward.
///
/// The seed averages are the arithmetic mean of the gains and losses over the
/// first `period` price changes (i.e. over `data[0..=period]`). Every later
/// change then updates both averages as an exponential moving average with
/// smoothing factor `2 / (period + 1)`.
///
/// NOTE(review): Wilder's original RSI smooths with `1 / period`. Some libraries
/// follow that definition, while the `2 / (period + 1)` factor used here gives
/// different values after the first output. Confirm this is intentional.
///
/// Returns `data.len() - period` values; element `k` corresponds to
/// `data[k + period]`. Whenever the average loss is exactly zero the value is
/// `100.0`, which includes a completely flat series.
///
/// # Errors
///
/// * [`IndicatorError::InvalidPeriod`] if `period == 0`.
/// * [`IndicatorError::InsufficientData`] if `data.len() <= period`, because
///   `period + 1` prices are needed to produce `period` changes.
pub(crate) fn rsi_raw(data: &[f64], period: usize) -> Result<Vec<f64>> {
    if period == 0 {
        return Err(IndicatorError::InvalidPeriod(
            "Period must be greater than 0".to_string(),
        ));
    }
    if data.len() <= period {
        return Err(IndicatorError::InsufficientData {
            need: period + 1,
            got: data.len(),
        });
    }

    let period_f = period as f64;
    let multiplier = 2.0 / (period_f + 1.0);

    // Seed: simple average of the first `period` changes. Non-positive changes
    // (including zero) go to the loss side; the abs of zero adds nothing.
    let mut avg_gain = 0.0f64;
    let mut avg_loss = 0.0f64;
    for pair in data[..=period].windows(2) {
        let change = pair[1] - pair[0];
        if change > 0.0 {
            avg_gain += change;
        } else {
            avg_loss += change.abs();
        }
    }
    avg_gain /= period_f;
    avg_loss /= period_f;

    let mut result = Vec::with_capacity(data.len() - period);
    result.push(rsi_from_averages(avg_gain, avg_loss));

    // Smoothing: each subsequent change pulls the averages toward its gain/loss.
    for pair in data[period..].windows(2) {
        let change = pair[1] - pair[0];
        let gain = if change > 0.0 { change } else { 0.0 };
        let loss = if change < 0.0 { change.abs() } else { 0.0 };
        avg_gain = (gain - avg_gain) * multiplier + avg_gain;
        avg_loss = (loss - avg_loss) * multiplier + avg_loss;
        result.push(rsi_from_averages(avg_gain, avg_loss));
    }
    Ok(result)
}

/// Converts an average gain and an average loss into an RSI value:
/// `100 - 100 / (1 + avg_gain / avg_loss)`, or `100.0` when `avg_loss` is zero
/// (avoids dividing by zero).
fn rsi_from_averages(avg_gain: f64, avg_loss: f64) -> f64 {
    if avg_loss == 0.0 {
        100.0
    } else {
        100.0 - 100.0 / (1.0 + avg_gain / avg_loss)
    }
}
/// Computes the RSI aligned index-for-index with `data`.
///
/// The returned vector has the same length as `data`. The first `period`
/// positions are `None` (not enough price changes yet). Each later position `i`
/// holds the RSI value that ends at `data[i]`.
///
/// # Errors
///
/// Propagates any error from `rsi_raw`: `period == 0`, or
/// `data.len() <= period`.
pub fn rsi(data: &[f64], period: usize) -> Result<Vec<Option<f64>>> {
    let values = rsi_raw(data, period)?;
    let warm_up = std::iter::repeat(None).take(period);
    Ok(warm_up.chain(values.into_iter().map(Some)).collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample price series: 16 prices, i.e. 15 changes.
    const SAMPLE: [f64; 16] = [
        44.0, 44.34, 44.09, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, 46.08, 45.89, 46.03,
        45.61, 46.28, 46.28, 46.0,
    ];

    #[test]
    fn test_rsi_basic() {
        let result = rsi(&SAMPLE, 14).unwrap();
        assert_eq!(result.len(), SAMPLE.len());
        for (i, &item) in result.iter().enumerate().take(14) {
            assert_eq!(item, None, "Index {} should be None", i);
        }
        // Every index from `period` onward must carry a value. An `if let` here
        // would let a missing value pass silently.
        for (i, &val) in result.iter().enumerate().skip(14) {
            let rsi_val = val.unwrap_or_else(|| panic!("Index {} should be Some", i));
            assert!(
                (0.0..=100.0).contains(&rsi_val),
                "RSI at index {} = {} is out of range [0, 100]",
                i,
                rsi_val
            );
        }
        // Seed window: gains sum to 3.62 and losses to 1.34, so the first RSI is
        // 100 * 3.62 / (3.62 + 1.34).
        let first = result[14].unwrap();
        assert!(
            (first - 72.983_870_967_741_94).abs() < 1e-9,
            "first RSI = {}",
            first
        );
    }

    #[test]
    fn test_rsi_all_gains() {
        let data: Vec<f64> = (0..30).map(|x| x as f64).collect();
        let result = rsi(&data, 14).unwrap();
        // With no losses the average loss stays exactly zero, so every value is 100.
        for (i, &val) in result.iter().enumerate().skip(14) {
            assert_eq!(val, Some(100.0), "Index {} should be 100", i);
        }
    }

    #[test]
    fn test_rsi_insufficient_data() {
        let data = [1.0, 2.0, 3.0];
        assert!(matches!(
            rsi(&data, 14),
            Err(IndicatorError::InsufficientData { need: 15, got: 3 })
        ));
        // Boundary: exactly `period` prices is still one short.
        assert!(matches!(
            rsi(&data, 3),
            Err(IndicatorError::InsufficientData { need: 4, got: 3 })
        ));
    }

    #[test]
    fn test_rsi_zero_period() {
        assert!(matches!(
            rsi(&[1.0, 2.0], 0),
            Err(IndicatorError::InvalidPeriod(_))
        ));
    }
}