use std::collections::{HashMap, VecDeque};
use super::types::TrendDirection;
use crate::error::IndicatorError;
use crate::indicator::{Indicator, IndicatorOutput};
use crate::registry::param_usize;
use crate::types::Candle;
/// Registry-facing wrapper that runs [`ADX`] over a whole candle series.
#[derive(Debug, Clone)]
pub struct AdxIndicator {
    /// Lookback handed to [`ADX::new`].
    pub period: usize,
}
impl AdxIndicator {
    /// Creates an ADX indicator with the given lookback `period`.
    pub fn new(period: usize) -> Self {
        AdxIndicator { period }
    }
}
impl Indicator for AdxIndicator {
    fn name(&self) -> &'static str {
        "ADX"
    }
    /// Minimum candle count: twice the lookback.
    fn required_len(&self) -> usize {
        self.period * 2
    }
    fn required_columns(&self) -> &[&'static str] {
        &["high", "low", "close"]
    }
    /// Streams every candle through a fresh [`ADX`] and emits three aligned columns.
    ///
    /// A row holds NaN in all three columns until the ADX itself is available;
    /// from then on `di_plus` / `di_minus` carry the latest directional indices
    /// (NaN if the calculator has none).
    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
        self.check_len(candles)?;
        let mut adx = ADX::new(self.period);
        let len = candles.len();
        let mut adx_col = Vec::with_capacity(len);
        let mut plus_col = Vec::with_capacity(len);
        let mut minus_col = Vec::with_capacity(len);
        for candle in candles {
            let Some(strength) = adx.update(candle.high, candle.low, candle.close) else {
                adx_col.push(f64::NAN);
                plus_col.push(f64::NAN);
                minus_col.push(f64::NAN);
                continue;
            };
            adx_col.push(strength);
            plus_col.push(adx.di_plus().unwrap_or(f64::NAN));
            minus_col.push(adx.di_minus().unwrap_or(f64::NAN));
        }
        Ok(IndicatorOutput::from_pairs([
            ("adx", adx_col),
            ("di_plus", plus_col),
            ("di_minus", minus_col),
        ]))
    }
}
/// Registry-facing wrapper that runs [`ATR`] over a whole candle series.
#[derive(Debug, Clone)]
pub struct AtrPrimIndicator {
    /// Lookback handed to [`ATR::new`].
    pub period: usize,
}
impl AtrPrimIndicator {
    /// Creates an ATR indicator with the given lookback `period`.
    pub fn new(period: usize) -> Self {
        AtrPrimIndicator { period }
    }
}
impl Indicator for AtrPrimIndicator {
    fn name(&self) -> &'static str {
        "AtrPrim"
    }
    /// Minimum candle count: one more than the lookback.
    fn required_len(&self) -> usize {
        self.period + 1
    }
    fn required_columns(&self) -> &[&'static str] {
        &["high", "low", "close"]
    }
    /// Streams every candle through a fresh [`ATR`]; warm-up rows are NaN.
    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
        self.check_len(candles)?;
        let mut atr = ATR::new(self.period);
        let column: Vec<f64> = candles
            .iter()
            .map(|candle| {
                atr.update(candle.high, candle.low, candle.close)
                    .unwrap_or(f64::NAN)
            })
            .collect();
        Ok(IndicatorOutput::from_pairs([("atr_prim", column)]))
    }
}
/// Registry-facing wrapper that runs [`EMA`] over the close column.
#[derive(Debug, Clone)]
pub struct EmaPrimIndicator {
    /// Lookback handed to [`EMA::new`].
    pub period: usize,
}
impl EmaPrimIndicator {
    /// Creates an EMA indicator with the given lookback `period`.
    pub fn new(period: usize) -> Self {
        EmaPrimIndicator { period }
    }
}
impl Indicator for EmaPrimIndicator {
    fn name(&self) -> &'static str {
        "EmaPrim"
    }
    /// Minimum candle count: the lookback itself.
    fn required_len(&self) -> usize {
        self.period
    }
    fn required_columns(&self) -> &[&'static str] {
        &["close"]
    }
    /// Streams every close through a fresh [`EMA`]; warm-up rows are NaN.
    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
        self.check_len(candles)?;
        let mut ema = EMA::new(self.period);
        let column: Vec<f64> = candles
            .iter()
            .map(|candle| ema.update(candle.close).unwrap_or(f64::NAN))
            .collect();
        Ok(IndicatorOutput::from_pairs([("ema_prim", column)]))
    }
}
/// Registry-facing wrapper that runs [`RSI`] over the close column.
#[derive(Debug, Clone)]
pub struct RsiPrimIndicator {
    /// Lookback handed to [`RSI::new`].
    pub period: usize,
}
impl RsiPrimIndicator {
    /// Creates an RSI indicator with the given lookback `period`.
    pub fn new(period: usize) -> Self {
        RsiPrimIndicator { period }
    }
}
impl Indicator for RsiPrimIndicator {
    fn name(&self) -> &'static str {
        "RsiPrim"
    }
    /// Minimum candle count: one more than the lookback (RSI works on changes).
    fn required_len(&self) -> usize {
        self.period + 1
    }
    fn required_columns(&self) -> &[&'static str] {
        &["close"]
    }
    /// Streams every close through a fresh [`RSI`]; warm-up rows are NaN.
    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
        self.check_len(candles)?;
        let mut rsi = RSI::new(self.period);
        let column: Vec<f64> = candles
            .iter()
            .map(|candle| rsi.update(candle.close).unwrap_or(f64::NAN))
            .collect();
        Ok(IndicatorOutput::from_pairs([("rsi_prim", column)]))
    }
}
/// Registry-facing wrapper that runs [`BollingerBands`] over the close column.
#[derive(Debug, Clone)]
pub struct BbPrimIndicator {
    /// Rolling window length handed to [`BollingerBands::new`].
    pub period: usize,
    /// Standard-deviation multiplier for the band offsets.
    pub std_dev: f64,
}
impl BbPrimIndicator {
    /// Creates a Bollinger Bands indicator from a window length and band multiplier.
    pub fn new(period: usize, std_dev: f64) -> Self {
        BbPrimIndicator { period, std_dev }
    }
}
impl Indicator for BbPrimIndicator {
    fn name(&self) -> &'static str {
        "BbPrim"
    }
    /// Minimum candle count: the window length.
    fn required_len(&self) -> usize {
        self.period
    }
    fn required_columns(&self) -> &[&'static str] {
        &["close"]
    }
    /// Streams every close through fresh [`BollingerBands`] and emits four aligned
    /// columns (upper, middle, lower, width); warm-up rows are NaN in all four.
    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
        self.check_len(candles)?;
        let mut bands = BollingerBands::new(self.period, self.std_dev);
        let len = candles.len();
        let mut upper = Vec::with_capacity(len);
        let mut mid = Vec::with_capacity(len);
        let mut lower = Vec::with_capacity(len);
        let mut width = Vec::with_capacity(len);
        for candle in candles {
            let (u, m, l, w) = bands.update(candle.close).map_or(
                (f64::NAN, f64::NAN, f64::NAN, f64::NAN),
                |v| (v.upper, v.middle, v.lower, v.width),
            );
            upper.push(u);
            mid.push(m);
            lower.push(l);
            width.push(w);
        }
        Ok(IndicatorOutput::from_pairs([
            ("bb_upper", upper),
            ("bb_mid", mid),
            ("bb_lower", lower),
            ("bb_width", width),
        ]))
    }
}
/// Builds a boxed [`AdxIndicator`] from registry parameters.
///
/// Reads `period` via `param_usize`, defaulting to 14; any error returned by
/// that helper is propagated unchanged.
pub fn factory<S: ::std::hash::BuildHasher>(
    params: &HashMap<String, String, S>,
) -> Result<Box<dyn Indicator>, IndicatorError> {
    let indicator = AdxIndicator::new(param_usize(params, "period", 14)?);
    Ok(Box::new(indicator))
}
/// Exponential moving average with an explicit warm-up period.
///
/// The first sample seeds the average. Each later sample is blended in as
/// `(price - prev) * multiplier + prev`, where `multiplier = 2 / (period + 1)`.
/// Values are reported only once at least `period` samples have been consumed,
/// so `EMA::new(1)` is ready after its first sample.
#[derive(Debug, Clone)]
pub struct EMA {
    period: usize,
    multiplier: f64,
    current_value: Option<f64>,
    initialized: bool,
    warmup_count: usize,
}
impl EMA {
    /// Creates an EMA for the given `period`.
    pub fn new(period: usize) -> Self {
        let multiplier = 2.0 / (period as f64 + 1.0);
        Self {
            period,
            multiplier,
            current_value: None,
            initialized: false,
            warmup_count: 0,
        }
    }
    /// Feeds one sample.
    ///
    /// Returns the updated average once `period` samples have been seen, `None` before that.
    pub fn update(&mut self, price: f64) -> Option<f64> {
        // Saturate so an extremely long stream cannot overflow the counter (a real risk
        // for 32-bit `usize`); once warmed up the exact count no longer matters.
        self.warmup_count = self.warmup_count.saturating_add(1);
        let next = match self.current_value {
            Some(prev_ema) => (price - prev_ema) * self.multiplier + prev_ema,
            None => price,
        };
        self.current_value = Some(next);
        // Checked after every sample, including the seed, so that `period == 1`
        // becomes ready on the very first update.
        if self.warmup_count >= self.period {
            self.initialized = true;
        }
        self.value()
    }
    /// Current average, or `None` while still warming up.
    pub fn value(&self) -> Option<f64> {
        if self.initialized {
            self.current_value
        } else {
            None
        }
    }
    /// Whether at least `period` samples have been consumed.
    pub fn is_ready(&self) -> bool {
        self.initialized
    }
    /// Configured period.
    pub fn period(&self) -> usize {
        self.period
    }
    /// Clears all state so the next sample re-seeds the average.
    pub fn reset(&mut self) {
        self.current_value = None;
        self.initialized = false;
        self.warmup_count = 0;
    }
}
/// Average True Range with Wilder smoothing.
///
/// The first `period` true ranges are averaged arithmetically to seed the value.
/// After that, each new true range is folded in as `(prev * (period - 1) + tr) / period`.
/// `period` is expected to be non-zero: with `period == 0` the smoothing step's
/// `period - 1` underflows, which panics in debug builds.
#[derive(Debug, Clone)]
pub struct ATR {
    period: usize,
    values: VecDeque<f64>,
    prev_close: Option<f64>,
    current_atr: Option<f64>,
}
impl ATR {
    /// Creates an ATR with the given lookback `period`.
    pub fn new(period: usize) -> Self {
        ATR {
            period,
            values: VecDeque::with_capacity(period),
            prev_close: None,
            current_atr: None,
        }
    }
    /// True range of one bar; without a previous close it is just `high - low`.
    fn true_range(high: f64, low: f64, prev_close: Option<f64>) -> f64 {
        let span = high - low;
        match prev_close {
            Some(pc) => span.max((high - pc).abs()).max((low - pc).abs()),
            None => span,
        }
    }
    /// Feeds one bar and returns the ATR once `period` bars have been seen.
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        // `replace` yields the previous close and stores this bar's close for next time.
        let tr = Self::true_range(high, low, self.prev_close.replace(close));
        self.values.push_back(tr);
        if self.values.len() > self.period {
            self.values.pop_front();
        }
        if self.values.len() < self.period {
            return self.current_atr;
        }
        let next = match self.current_atr {
            Some(prev_atr) => (prev_atr * (self.period - 1) as f64 + tr) / self.period as f64,
            None => self.values.iter().sum::<f64>() / self.period as f64,
        };
        self.current_atr = Some(next);
        self.current_atr
    }
    /// Latest ATR, or `None` while warming up.
    pub fn value(&self) -> Option<f64> {
        self.current_atr
    }
    /// Whether an ATR value is available.
    pub fn is_ready(&self) -> bool {
        self.current_atr.is_some()
    }
    /// Configured lookback.
    pub fn period(&self) -> usize {
        self.period
    }
    /// Drops all history so the next bar starts a fresh warm-up.
    pub fn reset(&mut self) {
        self.values.clear();
        self.prev_close = None;
        self.current_atr = None;
    }
}
/// Average Directional Index (ADX) together with its +DI / -DI components.
///
/// Each bar goes through four steps:
/// - Directional movement (+DM / -DM) is taken from the change in high/low versus
///   the previous bar, and each is smoothed with an [`EMA`].
/// - The true range is smoothed with [`ATR`].
/// - +DI / -DI are the smoothed DMs as a percentage of the ATR, and
///   `DX = |+DI - -DI| / (+DI + -DI) * 100`.
/// - The ADX is seeded with the mean of the first `period` DX values, then updated
///   as `(prev * (period - 1) + dx) / period`.
///
/// NOTE(review): +DM/-DM are smoothed with `EMA` (factor `2 / (period + 1)`),
/// while the true range uses `ATR`'s `1 / period` step. Wilder's original ADX
/// applies the same smoothing to both — confirm this mix is intended.
#[derive(Debug, Clone)]
pub struct ADX {
    /// Lookback shared by the ATR, both DM smoothers and the DX window.
    period: usize,
    /// Smoothed true range; denominator of +DI / -DI.
    atr: ATR,
    /// Smoothed +DM.
    plus_dm_ema: EMA,
    /// Smoothed -DM.
    minus_dm_ema: EMA,
    /// Most recent DX readings, capped at `period` entries; their mean seeds the ADX.
    dx_values: VecDeque<f64>,
    /// Previous bar's high, used for +DM.
    prev_high: Option<f64>,
    /// Previous bar's low, used for -DM.
    prev_low: Option<f64>,
    /// Latest ADX; `None` until `period` DX values have been collected.
    current_adx: Option<f64>,
    /// Latest +DI (percent).
    plus_dir_index: Option<f64>,
    /// Latest -DI (percent).
    minus_dir_index: Option<f64>,
}
impl ADX {
    /// Creates an ADX whose ATR, DM smoothers and DX window all use `period`.
    pub fn new(period: usize) -> Self {
        Self {
            period,
            atr: ATR::new(period),
            plus_dm_ema: EMA::new(period),
            minus_dm_ema: EMA::new(period),
            dx_values: VecDeque::with_capacity(period),
            prev_high: None,
            prev_low: None,
            current_adx: None,
            plus_dir_index: None,
            minus_dir_index: None,
        }
    }
    /// Feeds one bar and returns the current ADX, if any.
    ///
    /// If the ATR or either DM smoother is still warming up, or the ATR is not
    /// positive, or +DI + -DI is not positive, no DX is recorded for this bar. In
    /// that case the previous ADX (possibly `None`) is returned unchanged.
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        // Directional movement: only the larger of the two moves counts, and only
        // if it is positive. The first bar has no predecessor, so both are 0.
        let (plus_dm, minus_dm) = match (self.prev_high, self.prev_low) {
            (Some(prev_h), Some(prev_l)) => {
                let up_move = high - prev_h;
                let down_move = prev_l - low;
                let plus = if up_move > down_move && up_move > 0.0 {
                    up_move
                } else {
                    0.0
                };
                let minus = if down_move > up_move && down_move > 0.0 {
                    down_move
                } else {
                    0.0
                };
                (plus, minus)
            }
            _ => (0.0, 0.0),
        };
        self.prev_high = Some(high);
        self.prev_low = Some(low);
        // All three smoothers are advanced on every bar, even while warming up.
        let atr = self.atr.update(high, low, close);
        let smoothed_plus_dm = self.plus_dm_ema.update(plus_dm);
        let smoothed_minus_dm = self.minus_dm_ema.update(minus_dm);
        // The `atr_val > 0.0` guard avoids dividing by a zero true range.
        if let (Some(atr_val), Some(plus_dm_smooth), Some(minus_dm_smooth)) =
            (atr, smoothed_plus_dm, smoothed_minus_dm)
            && atr_val > 0.0
        {
            let plus_dir_index = (plus_dm_smooth / atr_val) * 100.0;
            let minus_dir_index = (minus_dm_smooth / atr_val) * 100.0;
            self.plus_dir_index = Some(plus_dir_index);
            self.minus_dir_index = Some(minus_dir_index);
            let di_sum = plus_dir_index + minus_dir_index;
            // DX is undefined when both DIs are zero; such bars are skipped.
            if di_sum > 0.0 {
                let di_diff = (plus_dir_index - minus_dir_index).abs();
                let dx = (di_diff / di_sum) * 100.0;
                // Keep only the latest `period` DX values.
                self.dx_values.push_back(dx);
                if self.dx_values.len() > self.period {
                    self.dx_values.pop_front();
                }
                if self.dx_values.len() >= self.period {
                    if let Some(prev_adx) = self.current_adx {
                        // Smoothing step: (prev * (period - 1) + dx) / period.
                        let new_adx =
                            (prev_adx * (self.period - 1) as f64 + dx) / self.period as f64;
                        self.current_adx = Some(new_adx);
                    } else {
                        // First ADX: simple mean of the first `period` DX values.
                        let sum: f64 = self.dx_values.iter().sum();
                        self.current_adx = Some(sum / self.period as f64);
                    }
                }
            }
        }
        self.current_adx
    }
    /// Latest ADX, or `None` while warming up.
    pub fn value(&self) -> Option<f64> {
        self.current_adx
    }
    /// Latest +DI (percent), or `None` before the first DI computation.
    pub fn plus_dir_index(&self) -> Option<f64> {
        self.plus_dir_index
    }
    /// Latest -DI (percent), or `None` before the first DI computation.
    pub fn minus_dir_index(&self) -> Option<f64> {
        self.minus_dir_index
    }
    /// Bullish when +DI > -DI, otherwise Bearish (ties count as Bearish).
    ///
    /// Returns `None` until both DIs exist.
    pub fn trend_direction(&self) -> Option<TrendDirection> {
        match (self.plus_dir_index, self.minus_dir_index) {
            (Some(plus), Some(minus)) => {
                if plus > minus {
                    Some(TrendDirection::Bullish)
                } else {
                    Some(TrendDirection::Bearish)
                }
            }
            _ => None,
        }
    }
    /// Whether an ADX value is available.
    pub fn is_ready(&self) -> bool {
        self.current_adx.is_some()
    }
    /// Configured lookback.
    pub fn period(&self) -> usize {
        self.period
    }
    /// Alias of [`ADX::plus_dir_index`].
    pub fn di_plus(&self) -> Option<f64> {
        self.plus_dir_index
    }
    /// Alias of [`ADX::minus_dir_index`].
    pub fn di_minus(&self) -> Option<f64> {
        self.minus_dir_index
    }
    /// Clears all sub-smoothers and cached values so the next bar starts a fresh warm-up.
    pub fn reset(&mut self) {
        self.atr.reset();
        self.plus_dm_ema.reset();
        self.minus_dm_ema.reset();
        self.dx_values.clear();
        self.prev_high = None;
        self.prev_low = None;
        self.current_adx = None;
        self.plus_dir_index = None;
        self.minus_dir_index = None;
    }
}
/// One Bollinger Bands reading produced by [`BollingerBands::update`].
#[derive(Debug, Clone, Copy)]
pub struct BollingerBandsValues {
    /// Middle band plus `std_dev * multiplier`.
    pub upper: f64,
    /// Simple moving average of the window.
    pub middle: f64,
    /// Middle band minus `std_dev * multiplier`.
    pub lower: f64,
    /// `(upper - lower) / middle * 100`, or 0 when the middle band is not positive.
    pub width: f64,
    /// Percentage of remembered widths strictly below this one (50 while history is short).
    pub width_percentile: f64,
    /// Position of the price between the bands (0 = lower, 1 = upper; 0.5 if the bands coincide).
    pub percent_b: f64,
    /// Population standard deviation of the window.
    pub std_dev: f64,
}
impl BollingerBandsValues {
    const OVERBOUGHT_PERCENT_B: f64 = 0.95;
    const OVERSOLD_PERCENT_B: f64 = 0.05;

    /// True when %B is at or above 0.95.
    pub fn is_overbought(&self) -> bool {
        self.percent_b >= Self::OVERBOUGHT_PERCENT_B
    }
    /// True when %B is at or below 0.05.
    pub fn is_oversold(&self) -> bool {
        self.percent_b <= Self::OVERSOLD_PERCENT_B
    }
    /// True when the width percentile is at or above `threshold_percentile`.
    pub fn is_high_volatility(&self, threshold_percentile: f64) -> bool {
        self.width_percentile >= threshold_percentile
    }
    /// True when the width percentile is at or below `threshold_percentile`.
    pub fn is_squeeze(&self, threshold_percentile: f64) -> bool {
        self.width_percentile <= threshold_percentile
    }
}
/// Streaming Bollinger Bands over a fixed-size price window.
///
/// Also remembers recent band widths, so each reading can report how wide the
/// bands are relative to their recent history.
#[derive(Debug, Clone)]
pub struct BollingerBands {
    period: usize,
    std_dev_multiplier: f64,
    prices: VecDeque<f64>,
    width_history: VecDeque<f64>,
    width_history_size: usize,
}
impl BollingerBands {
    /// How many band widths are remembered for percentile ranking.
    const WIDTH_HISTORY_CAPACITY: usize = 100;
    /// Below this many remembered widths the percentile is reported as 50.
    const MIN_WIDTH_SAMPLES: usize = 10;

    /// Creates bands over a `period`-sized window offset by `std_dev_multiplier` deviations.
    pub fn new(period: usize, std_dev_multiplier: f64) -> Self {
        BollingerBands {
            period,
            std_dev_multiplier,
            prices: VecDeque::with_capacity(period),
            width_history: VecDeque::with_capacity(Self::WIDTH_HISTORY_CAPACITY),
            width_history_size: Self::WIDTH_HISTORY_CAPACITY,
        }
    }
    /// Appends `value`, evicting the oldest entry once the deque exceeds `cap`.
    fn push_bounded(buf: &mut VecDeque<f64>, value: f64, cap: usize) {
        buf.push_back(value);
        if buf.len() > cap {
            buf.pop_front();
        }
    }
    /// Feeds one price.
    ///
    /// Returns a reading once the window is full, `None` before that.
    pub fn update(&mut self, price: f64) -> Option<BollingerBandsValues> {
        Self::push_bounded(&mut self.prices, price, self.period);
        if !self.is_ready() {
            return None;
        }
        let n = self.period as f64;
        let sma = self.prices.iter().sum::<f64>() / n;
        let variance = self.prices.iter().map(|p| (p - sma).powi(2)).sum::<f64>() / n;
        let std_dev = variance.sqrt();
        let offset = std_dev * self.std_dev_multiplier;
        let upper = sma + offset;
        let lower = sma - offset;
        let spread = upper - lower;
        let width = if sma > 0.0 { spread / sma * 100.0 } else { 0.0 };
        // The current width is recorded before ranking, so it is part of its own history.
        Self::push_bounded(&mut self.width_history, width, self.width_history_size);
        let width_percentile = self.calculate_width_percentile(width);
        let percent_b = if spread > 0.0 {
            (price - lower) / spread
        } else {
            0.5
        };
        Some(BollingerBandsValues {
            upper,
            middle: sma,
            lower,
            width,
            width_percentile,
            percent_b,
            std_dev,
        })
    }
    /// Percentage of remembered widths strictly below `current_width`.
    ///
    /// Returns 50 while fewer than `MIN_WIDTH_SAMPLES` widths are remembered.
    fn calculate_width_percentile(&self, current_width: f64) -> f64 {
        let total = self.width_history.len();
        if total < Self::MIN_WIDTH_SAMPLES {
            return 50.0;
        }
        let narrower = self
            .width_history
            .iter()
            .filter(|w| **w < current_width)
            .count();
        (narrower as f64 / total as f64) * 100.0
    }
    /// Whether the price window is full.
    pub fn is_ready(&self) -> bool {
        self.prices.len() >= self.period
    }
    /// Window length.
    pub fn period(&self) -> usize {
        self.period
    }
    /// Standard-deviation multiplier used for the band offsets.
    pub fn std_dev_multiplier(&self) -> f64 {
        self.std_dev_multiplier
    }
    /// Clears both the price window and the width history.
    pub fn reset(&mut self) {
        self.prices.clear();
        self.width_history.clear();
    }
}
/// Relative Strength Index built from smoothed gains and losses.
///
/// Average gain and loss are tracked with [`EMA`] (factor `2 / (period + 1)`).
/// When the average loss is exactly zero the RSI is reported as 100.
#[derive(Debug, Clone)]
pub struct RSI {
    period: usize,
    gains: EMA,
    losses: EMA,
    prev_close: Option<f64>,
    last_rsi: Option<f64>,
}
impl RSI {
    /// Creates an RSI over `period` close-to-close changes.
    pub fn new(period: usize) -> Self {
        RSI {
            period,
            gains: EMA::new(period),
            losses: EMA::new(period),
            prev_close: None,
            last_rsi: None,
        }
    }
    /// Feeds one close.
    ///
    /// Returns the RSI once both smoothers are warmed up. The first close only
    /// establishes a baseline and always yields `None`.
    pub fn update(&mut self, close: f64) -> Option<f64> {
        // Store this close for the next call up front; bail out if there was no baseline yet.
        let previous = self.prev_close.replace(close)?;
        let delta = close - previous;
        let gain = if delta > 0.0 { delta } else { 0.0 };
        let loss = if delta < 0.0 { -delta } else { 0.0 };
        // Both smoothers must see every change, so evaluate both before matching.
        let smoothed = (self.gains.update(gain), self.losses.update(loss));
        let (Some(avg_gain), Some(avg_loss)) = smoothed else {
            return None;
        };
        let rsi = if avg_loss == 0.0 {
            100.0
        } else {
            100.0 - (100.0 / (1.0 + avg_gain / avg_loss))
        };
        self.last_rsi = Some(rsi);
        Some(rsi)
    }
    /// Most recent RSI returned by [`RSI::update`], if any.
    pub fn value(&self) -> Option<f64> {
        self.last_rsi
    }
    /// Whether both the gain and loss smoothers are warmed up.
    pub fn is_ready(&self) -> bool {
        self.gains.is_ready() && self.losses.is_ready()
    }
    /// Configured period.
    pub fn period(&self) -> usize {
        self.period
    }
    /// Clears smoothers, baseline close and cached value.
    pub fn reset(&mut self) {
        self.gains.reset();
        self.losses.reset();
        self.prev_close = None;
        self.last_rsi = None;
    }
}
/// Arithmetic mean of `prices`; returns 0.0 for an empty slice.
pub fn calculate_sma(prices: &[f64]) -> f64 {
    match prices.len() {
        0 => 0.0,
        count => prices.iter().sum::<f64>() / count as f64,
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the streaming primitives (EMA, ATR, ADX, Bollinger Bands, RSI)
    //! and `calculate_sma`.
    //!
    //! Tests that inspect a computed value first `expect` the primitive to be warmed
    //! up. A primitive that never becomes ready therefore fails loudly instead of
    //! silently skipping its assertion.
    use super::*;

    #[test]
    fn test_ema_creation() {
        let ema = EMA::new(10);
        assert_eq!(ema.period(), 10);
        assert!(!ema.is_ready());
        assert!(ema.value().is_none());
    }
    #[test]
    fn test_ema_warmup() {
        let mut ema = EMA::new(10);
        for i in 1..10 {
            let result = ema.update(i as f64 * 10.0);
            assert!(result.is_none(), "Should be None during warmup at step {i}");
        }
        let result = ema.update(100.0);
        assert!(result.is_some(), "Should be ready after 10 updates");
        assert!(ema.is_ready());
    }
    #[test]
    fn test_ema_calculation() {
        let mut ema = EMA::new(10);
        for i in 1..=10 {
            ema.update(i as f64 * 10.0);
        }
        assert!(ema.is_ready());
        let value = ema.value().unwrap();
        assert!(value > 10.0 && value <= 100.0);
    }
    #[test]
    fn test_ema_tracks_trend() {
        let mut ema = EMA::new(5);
        for _ in 0..5 {
            ema.update(100.0);
        }
        let stable = ema.value().unwrap();
        for _ in 0..10 {
            ema.update(110.0);
        }
        let after_up = ema.value().unwrap();
        assert!(after_up > stable, "EMA should increase with rising prices");
    }
    #[test]
    fn test_ema_reset() {
        let mut ema = EMA::new(5);
        for _ in 0..10 {
            ema.update(100.0);
        }
        assert!(ema.is_ready());
        ema.reset();
        assert!(!ema.is_ready());
        assert!(ema.value().is_none());
    }
    #[test]
    fn test_atr_creation() {
        let atr = ATR::new(14);
        assert_eq!(atr.period(), 14);
        assert!(!atr.is_ready());
    }
    #[test]
    fn test_atr_warmup() {
        let mut atr = ATR::new(14);
        for i in 1..=14 {
            let base = 100.0 + i as f64;
            let result = atr.update(base + 1.0, base - 1.0, base);
            if i < 14 {
                assert!(result.is_none(), "ATR should still be warming up at bar {i}");
            } else {
                assert!(result.is_some(), "ATR should be ready after 14 bars");
            }
        }
        assert!(atr.is_ready());
    }
    #[test]
    fn test_atr_increases_with_volatility() {
        let mut atr = ATR::new(14);
        for i in 1..=14 {
            let base = 100.0 + i as f64 * 0.1;
            atr.update(base + 0.5, base - 0.5, base);
        }
        let low_vol_atr = atr.value().unwrap();
        for i in 0..20 {
            let base = 100.0 + if i % 2 == 0 { 5.0 } else { -5.0 };
            atr.update(base + 3.0, base - 3.0, base);
        }
        let high_vol_atr = atr.value().unwrap();
        assert!(
            high_vol_atr > low_vol_atr,
            "ATR should increase with volatility: {high_vol_atr} vs {low_vol_atr}"
        );
    }
    #[test]
    fn test_atr_reset() {
        let mut atr = ATR::new(14);
        for i in 0..20 {
            let base = 100.0 + i as f64;
            atr.update(base + 1.0, base - 1.0, base);
        }
        assert!(atr.is_ready());
        atr.reset();
        assert!(!atr.is_ready());
        assert!(atr.value().is_none());
    }
    #[test]
    fn test_adx_creation() {
        let adx = ADX::new(14);
        assert_eq!(adx.period(), 14);
        assert!(!adx.is_ready());
    }
    #[test]
    fn test_adx_trending_detection() {
        let mut adx = ADX::new(14);
        for i in 1..=50 {
            let high = 100.0 + i as f64 * 2.0;
            let low = 100.0 + i as f64 * 2.0 - 1.0;
            let close = 100.0 + i as f64 * 2.0 - 0.5;
            adx.update(high, low, close);
        }
        let adx_value = adx.value().expect("ADX should be ready after 50 bars");
        assert!(
            adx_value > 20.0,
            "ADX should indicate trend in strong uptrend: {adx_value}"
        );
    }
    #[test]
    fn test_adx_trend_direction() {
        let mut adx = ADX::new(14);
        for i in 1..=50 {
            let high = 100.0 + i as f64 * 2.0;
            let low = 100.0 + i as f64 * 2.0 - 1.0;
            let close = 100.0 + i as f64 * 2.0 - 0.5;
            adx.update(high, low, close);
        }
        let dir = adx
            .trend_direction()
            .expect("trend direction should be available after 50 bars");
        assert_eq!(
            dir,
            TrendDirection::Bullish,
            "Should detect bullish direction in uptrend"
        );
    }
    #[test]
    fn test_adx_di_values() {
        let mut adx = ADX::new(14);
        for i in 1..=50 {
            let high = 100.0 + i as f64 * 2.0;
            let low = 100.0 + i as f64 * 2.0 - 1.0;
            let close = 100.0 + i as f64 * 2.0 - 0.5;
            adx.update(high, low, close);
        }
        let plus = adx.plus_dir_index().expect("+DI should be available");
        let minus = adx.minus_dir_index().expect("-DI should be available");
        assert!(
            plus > minus,
            "+DI ({plus}) should be > -DI ({minus}) in uptrend"
        );
    }
    #[test]
    fn test_adx_reset() {
        let mut adx = ADX::new(14);
        for i in 1..=50 {
            let base = 100.0 + i as f64;
            adx.update(base + 1.0, base - 1.0, base);
        }
        assert!(adx.is_ready());
        adx.reset();
        assert!(!adx.is_ready());
        assert!(adx.value().is_none());
        assert!(adx.plus_dir_index().is_none());
        assert!(adx.minus_dir_index().is_none());
    }
    #[test]
    fn test_bb_creation() {
        let bb = BollingerBands::new(20, 2.0);
        assert_eq!(bb.period(), 20);
        assert_eq!(bb.std_dev_multiplier(), 2.0);
        assert!(!bb.is_ready());
    }
    #[test]
    fn test_bb_warmup() {
        let mut bb = BollingerBands::new(20, 2.0);
        for i in 1..20 {
            let result = bb.update(100.0 + i as f64 * 0.1);
            assert!(result.is_none());
        }
        let result = bb.update(102.0);
        assert!(result.is_some());
        assert!(bb.is_ready());
    }
    #[test]
    fn test_bb_band_ordering() {
        let mut bb = BollingerBands::new(20, 2.0);
        for i in 1..=25 {
            let price = 100.0 + (i as f64 % 5.0);
            bb.update(price);
        }
        let result = bb.update(102.0).unwrap();
        assert!(
            result.upper > result.middle,
            "Upper band ({}) should be > middle ({})",
            result.upper,
            result.middle
        );
        assert!(
            result.middle > result.lower,
            "Middle ({}) should be > lower ({})",
            result.middle,
            result.lower
        );
    }
    #[test]
    fn test_bb_percent_b() {
        let mut bb = BollingerBands::new(20, 2.0);
        for i in 1..=20 {
            bb.update(100.0 + (i as f64 % 3.0));
        }
        let v = bb
            .update(100.0 + 1.0)
            .expect("bands should be ready after 21 prices");
        assert!(
            v.percent_b >= 0.0 && v.percent_b <= 1.0,
            "%B should be in [0,1]: {}",
            v.percent_b
        );
    }
    #[test]
    fn test_bb_squeeze_detection() {
        let mut bb = BollingerBands::new(20, 2.0);
        for i in 0..50 {
            let price = 100.0 + if i % 2 == 0 { 10.0 } else { -10.0 };
            bb.update(price);
        }
        for _ in 0..50 {
            bb.update(100.0);
        }
        let result = bb.update(100.0).unwrap();
        assert!(
            result.width_percentile < 50.0,
            "Constant prices should produce low width percentile: {}",
            result.width_percentile
        );
    }
    #[test]
    fn test_bb_overbought_oversold() {
        let mut bb = BollingerBands::new(20, 2.0);
        for _ in 0..20 {
            bb.update(100.0);
        }
        let result = bb.update(110.0).unwrap();
        assert!(
            result.is_overbought(),
            "Price far above bands should be overbought, %B = {}",
            result.percent_b
        );
    }
    #[test]
    fn test_bb_oversold() {
        let mut bb = BollingerBands::new(20, 2.0);
        for _ in 0..20 {
            bb.update(100.0);
        }
        let result = bb.update(90.0).unwrap();
        assert!(
            result.is_oversold(),
            "Price far below bands should be oversold, %B = {}",
            result.percent_b
        );
        assert!(!result.is_overbought());
    }
    #[test]
    fn test_bb_reset() {
        let mut bb = BollingerBands::new(20, 2.0);
        for i in 0..25 {
            bb.update(100.0 + i as f64);
        }
        assert!(bb.is_ready());
        bb.reset();
        assert!(!bb.is_ready());
    }
    #[test]
    fn test_rsi_creation() {
        let rsi = RSI::new(14);
        assert_eq!(rsi.period(), 14);
        assert!(!rsi.is_ready());
    }
    #[test]
    fn test_rsi_bullish_market() {
        let mut rsi = RSI::new(14);
        let mut last_rsi = None;
        for i in 0..30 {
            let price = 100.0 + i as f64;
            if let Some(val) = rsi.update(price) {
                last_rsi = Some(val);
            }
        }
        let val = last_rsi.expect("RSI should be ready after 30 closes");
        assert!(
            val > 50.0,
            "RSI should be above 50 in bullish market: {val}"
        );
    }
    #[test]
    fn test_rsi_bearish_market() {
        let mut rsi = RSI::new(14);
        let mut last_rsi = None;
        for i in 0..30 {
            let price = 200.0 - i as f64;
            if let Some(val) = rsi.update(price) {
                last_rsi = Some(val);
            }
        }
        let val = last_rsi.expect("RSI should be ready after 30 closes");
        assert!(
            val < 50.0,
            "RSI should be below 50 in bearish market: {val}"
        );
    }
    #[test]
    fn test_rsi_range() {
        let mut rsi = RSI::new(14);
        let mut produced = 0;
        for i in 0..50 {
            let price = 100.0 + (i as f64 * 0.7).sin() * 10.0;
            if let Some(val) = rsi.update(price) {
                produced += 1;
                assert!(
                    (0.0..=100.0).contains(&val),
                    "RSI should be in [0, 100]: {val}"
                );
            }
        }
        assert!(produced > 0, "RSI should produce values within 50 closes");
    }
    #[test]
    fn test_rsi_value_cached() {
        let mut rsi = RSI::new(14);
        assert!(
            rsi.value().is_none(),
            "value() should be None before warmup"
        );
        let mut last_from_update = None;
        for i in 0..30 {
            let price = 100.0 + i as f64;
            if let Some(v) = rsi.update(price) {
                last_from_update = Some(v);
            }
        }
        assert!(last_from_update.is_some(), "RSI should be ready after 30 closes");
        assert_eq!(
            rsi.value(),
            last_from_update,
            "value() must equal the last update() result"
        );
    }
    #[test]
    fn test_rsi_reset_clears_value() {
        let mut rsi = RSI::new(14);
        for i in 0..30 {
            rsi.update(100.0 + i as f64);
        }
        assert!(rsi.value().is_some());
        rsi.reset();
        assert!(rsi.value().is_none(), "value() should be None after reset");
    }
    #[test]
    fn test_calculate_sma() {
        assert_eq!(calculate_sma(&[1.0, 2.0, 3.0, 4.0, 5.0]), 3.0);
        assert_eq!(calculate_sma(&[100.0]), 100.0);
        assert_eq!(calculate_sma(&[]), 0.0);
    }
    #[test]
    fn test_calculate_sma_precision() {
        let prices = [10.0, 20.0, 30.0];
        let sma = calculate_sma(&prices);
        assert!((sma - 20.0).abs() < f64::EPSILON);
    }
}