use std::collections::VecDeque;
use quant_primitives::Candle;
use rust_decimal::Decimal;
use crate::error::IndicatorError;
use crate::indicator::Indicator;
use crate::series::Series;
/// Average True Range indicator.
///
/// Holds only its configuration; the series itself is produced by
/// [`Indicator::compute`].
#[derive(Debug, Clone)]
pub struct Atr {
    /// Smoothing window length; guaranteed non-zero when built through [`Atr::new`].
    period: usize,
    /// Display name cached at construction time, e.g. `ATR(14)`.
    name: String,
}

impl Atr {
    /// Builds an ATR indicator over `period` bars.
    ///
    /// # Errors
    /// Returns [`IndicatorError::InvalidParameter`] when `period` is zero.
    pub fn new(period: usize) -> Result<Self, IndicatorError> {
        match period {
            0 => Err(IndicatorError::InvalidParameter {
                message: "ATR period must be > 0".to_string(),
            }),
            window => Ok(Self {
                name: format!("ATR({})", window),
                period: window,
            }),
        }
    }
}
impl Indicator for Atr {
    /// Returns the cached display name, e.g. `ATR(14)`.
    fn name(&self) -> &str {
        &self.name
    }

    /// Minimum number of candles `compute` accepts: `period + 1`.
    fn warmup_period(&self) -> usize {
        self.period + 1
    }

    /// Computes the ATR series for `candles`.
    ///
    /// One true range is derived per candle (index-aligned with `candles`).
    /// - The first candle has no previous close, so its true range is `high - low`.
    /// - Every later candle uses [`true_range`] against the preceding close.
    ///
    /// The first ATR value is the simple mean of the first `period` true ranges and is
    /// stamped with candle `period - 1`. Each following candle applies Wilder's smoothing,
    /// `atr = (atr * (period - 1) + tr) / period`.
    ///
    /// # Errors
    /// Returns [`IndicatorError::InsufficientData`] when fewer than `period + 1`
    /// candles are supplied.
    fn compute(&self, candles: &[Candle]) -> Result<Series, IndicatorError> {
        let required = self.period + 1;
        if candles.len() < required {
            return Err(IndicatorError::InsufficientData {
                required,
                actual: candles.len(),
            });
        }

        // Exactly one true range per candle; the exact-size iterator lets `collect`
        // reserve the full length up front.
        let first = &candles[0];
        let true_ranges: Vec<Decimal> = std::iter::once(first.high() - first.low())
            .chain(
                candles
                    .windows(2)
                    .map(|pair| true_range(&pair[1], pair[0].close())),
            )
            .collect();

        let period_dec = Decimal::from(self.period);
        // Weight applied to the previous ATR in Wilder's smoothing; loop-invariant.
        let prev_weight = period_dec - Decimal::ONE;

        // The seed value occupies index `period - 1`, and every candle from index `period`
        // onward adds one more value, giving `len - period + 1` entries in total.
        // NOTE(review): `required`/`warmup_period` is `period + 1`, yet the seed only
        // consumes candles `0..period` (using `high - low` for candle 0). A common
        // alternative convention skips candle 0 and seeds from true ranges `1..=period`.
        // Confirm which is intended before changing either side; tests may pin this.
        let mut values = Vec::with_capacity(candles.len() - self.period + 1);
        let initial_sum: Decimal = true_ranges[..self.period].iter().sum();
        let mut atr = initial_sum / period_dec;
        values.push((candles[self.period - 1].timestamp(), atr));

        for (candle, tr) in candles.iter().zip(&true_ranges).skip(self.period) {
            atr = (atr * prev_weight + *tr) / period_dec;
            values.push((candle.timestamp(), atr));
        }
        Ok(Series::new(values))
    }
}
pub fn true_range(candle: &Candle, prev_close: Decimal) -> Decimal {
let high_low = candle.high() - candle.low();
let high_prev = (candle.high() - prev_close).abs();
let low_prev = (candle.low() - prev_close).abs();
high_low.max(high_prev).max(low_prev)
}
/// Appends `new_tr` to a rolling window capped at `period` entries and returns the
/// arithmetic mean of the window.
///
/// The oldest entries are evicted from the front until at most `period` values remain.
/// A window that arrives over-full (e.g. because the caller lowered `period`) is
/// therefore trimmed back to size, not shortened by a single element.
///
/// Returns `Decimal::ZERO` when the window ends up empty. After the push that can only
/// happen when `period == 0`.
pub fn rolling_atr_mean(window: &mut VecDeque<Decimal>, new_tr: Decimal, period: usize) -> Decimal {
    window.push_back(new_tr);
    while window.len() > period {
        window.pop_front();
    }
    if window.is_empty() {
        return Decimal::ZERO;
    }
    let sum: Decimal = window.iter().copied().sum();
    sum / Decimal::from(window.len())
}
#[cfg(test)]
#[path = "atr_tests.rs"]
mod tests;