use crate::indicators::traits::Indicator;
use crate::indicators::utils::validate_data_length;
use crate::indicators::utils::validate_period;
use crate::indicators::{Candle, IndicatorError};
use std::collections::VecDeque;
/// State for an Average True Range (ATR) indicator.
///
/// The same state backs both the batch (`calculate`) and the streaming
/// (`next`) entry points of the `Indicator` implementation in this file.
#[derive(Debug)]
pub struct Atr {
/// Window length: the number of true-range values averaged for the first
/// ATR value and the divisor used by the recursive smoothing step.
period: usize,
/// Close of the most recently processed candle; `None` until a candle has
/// been seen (or after `reset`).
prev_close: Option<f64>,
/// Most recently produced ATR value; `None` until `period` true ranges
/// have been accumulated.
current_atr: Option<f64>,
/// Most recent true-range values, trimmed to at most `period` entries.
tr_values: VecDeque<f64>,
}
impl Atr {
    /// Builds an empty ATR indicator for the given `period`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `validate_period(period, 1)` reports; the
    /// tests in this file expect a `period` of `0` to be rejected.
    pub fn new(period: usize) -> Result<Self, IndicatorError> {
        validate_period(period, 1)?;
        let indicator = Self {
            tr_values: VecDeque::with_capacity(period),
            current_atr: None,
            prev_close: None,
            period,
        };
        Ok(indicator)
    }

    /// True range of `candle`.
    ///
    /// Without a previous close this is simply `high - low`. With one, it is
    /// the largest of `high - low`, `|high - prev_close|` and
    /// `|low - prev_close|`.
    fn true_range(candle: &Candle, prev_close: Option<f64>) -> f64 {
        let bar_range = candle.high - candle.low;
        prev_close.map_or(bar_range, |reference| {
            let from_high = (candle.high - reference).abs();
            let from_low = (candle.low - reference).abs();
            bar_range.max(from_high).max(from_low)
        })
    }

    /// Arithmetic mean of `tr_values`, used as the first ATR value.
    fn initial_atr(tr_values: &[f64]) -> f64 {
        let total: f64 = tr_values.iter().sum();
        total / tr_values.len() as f64
    }

    /// One recursive smoothing step:
    /// `(prev_atr * (period - 1) + current_tr) / period`.
    fn smooth_atr(&self, prev_atr: f64, current_tr: f64) -> f64 {
        let window = self.period as f64;
        let carried_over = prev_atr * (self.period - 1) as f64;
        (carried_over + current_tr) / window
    }
}
impl Indicator<Candle, f64> for Atr {
    /// Computes the ATR series for `data` in one pass.
    ///
    /// Any previous state is discarded first. The first output is the plain
    /// mean of the first `period` true ranges; each later output applies one
    /// smoothing step, so the result holds `data.len() - period + 1` values.
    /// Afterwards the indicator keeps the last close, the last ATR and the
    /// last `period` true ranges, so `next` continues from where the batch
    /// ended.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_data_length(data, self.period)`;
    /// the tests in this file expect `InsufficientData` when fewer than
    /// `period` candles are supplied. State is left untouched in that case.
    fn calculate(&mut self, data: &[Candle]) -> Result<Vec<f64>, IndicatorError> {
        validate_data_length(data, self.period)?;
        self.reset();

        // Pair every candle with the close that preceded it (none for the first).
        let earlier_closes =
            std::iter::once(None).chain(data.iter().map(|candle| Some(candle.close)));
        let true_ranges: Vec<f64> = data
            .iter()
            .zip(earlier_closes)
            .map(|(candle, earlier_close)| Self::true_range(candle, earlier_close))
            .collect();

        let (warm_up, remainder) = true_ranges.split_at(self.period);
        let mut latest = Self::initial_atr(warm_up);
        let mut output = Vec::with_capacity(remainder.len() + 1);
        output.push(latest);
        for &tr in remainder {
            latest = self.smooth_atr(latest, tr);
            output.push(latest);
        }

        // Persist enough state for streaming updates to pick up seamlessly.
        self.prev_close = data.last().map(|candle| candle.close);
        self.current_atr = Some(latest);
        let tail_start = true_ranges.len() - self.period;
        self.tr_values
            .extend(true_ranges[tail_start..].iter().copied());

        Ok(output)
    }

    /// Feeds one candle and returns the updated ATR once it is available.
    ///
    /// Yields `Ok(None)` while fewer than `period` true ranges have been
    /// collected. The first available value is the mean of the buffered true
    /// ranges; every later value is a smoothing step on the previous ATR.
    fn next(&mut self, value: Candle) -> Result<Option<f64>, IndicatorError> {
        let latest_tr = Self::true_range(&value, self.prev_close);
        self.prev_close = Some(value.close);

        self.tr_values.push_back(latest_tr);
        if self.tr_values.len() > self.period {
            self.tr_values.pop_front();
        }
        if self.tr_values.len() != self.period {
            return Ok(None);
        }

        let updated = if let Some(previous) = self.current_atr {
            self.smooth_atr(previous, latest_tr)
        } else {
            Self::initial_atr(self.tr_values.make_contiguous())
        };
        self.current_atr = Some(updated);
        Ok(Some(updated))
    }

    /// Clears all accumulated state; the configured `period` is kept.
    fn reset(&mut self) {
        self.tr_values.clear();
        self.current_atr = None;
        self.prev_close = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const FLOAT_EPSILON: f64 = 1e-10;

    /// Asserts that `a` (actual) lies within `FLOAT_EPSILON` of `b` (expected).
    fn assert_float_eq(a: f64, b: f64) {
        let difference = (a - b).abs();
        assert!(difference < FLOAT_EPSILON, "Expected {} but got {}", b, a);
    }

    /// Builds a candle with a fixed volume of `1000.0`.
    fn create_test_candle(timestamp: u64, open: f64, high: f64, low: f64, close: f64) -> Candle {
        let volume = 1000.0;
        Candle {
            timestamp,
            open,
            high,
            low,
            close,
            volume,
        }
    }

    /// Four steadily rising candles shared by several tests.
    fn trending_candles() -> Vec<Candle> {
        vec![
            create_test_candle(0, 10.0, 12.0, 9.0, 11.0),
            create_test_candle(1, 11.0, 14.0, 10.0, 13.0),
            create_test_candle(2, 13.0, 15.0, 11.0, 14.0),
            create_test_candle(3, 14.0, 16.0, 12.0, 15.0),
        ]
    }

    #[test]
    fn test_atr_new() {
        for &period in &[1usize, 14, 100] {
            assert!(Atr::new(period).is_ok());
        }
        assert!(Atr::new(0).is_err());
    }

    #[test]
    fn test_true_range_calculation() {
        // (candle, previous close, expected true range)
        let cases = [
            (create_test_candle(0, 10.0, 15.0, 8.0, 12.0), None, 7.0),
            (create_test_candle(1, 11.0, 13.0, 9.0, 10.0), Some(12.0), 4.0),
            (create_test_candle(2, 8.0, 9.0, 7.0, 8.0), Some(10.0), 3.0),
        ];
        for (candle, prev_close, expected) in cases.iter() {
            assert_float_eq(Atr::true_range(candle, *prev_close), *expected);
        }
    }

    #[test]
    fn test_atr_calculation_basic() {
        let mut atr = Atr::new(3).unwrap();
        let series = atr.calculate(&trending_candles()).unwrap();
        assert_eq!(series.len(), 2);
        assert_float_eq(series[0], 3.666666666666667);
        assert_float_eq(series[1], 3.777777777777778);
    }

    #[test]
    fn test_atr_with_gaps() {
        let mut atr = Atr::new(3).unwrap();
        // The third candle gaps up and the fourth gaps back down.
        let gapped = vec![
            create_test_candle(0, 10.0, 12.0, 9.0, 11.0),
            create_test_candle(1, 11.0, 14.0, 10.0, 13.0),
            create_test_candle(2, 15.0, 17.0, 14.0, 16.0),
            create_test_candle(3, 12.0, 13.0, 11.0, 12.0),
        ];
        let series = atr.calculate(&gapped).unwrap();
        assert_eq!(series.len(), 2);
        assert!(series[1] > series[0]);
    }

    #[test]
    fn test_atr_next_value() {
        let mut atr = Atr::new(3).unwrap();
        let mut feed = trending_candles().into_iter();

        // Warm-up: the first two candles cannot produce a value yet.
        for candle in feed.by_ref().take(2) {
            assert_eq!(atr.next(candle).unwrap(), None);
        }

        let expected_values = [3.666666666666667, 3.777777777777778];
        for (candle, &expected) in feed.zip(expected_values.iter()) {
            let produced = atr.next(candle).unwrap().unwrap();
            assert_float_eq(produced, expected);
        }
    }

    #[test]
    fn test_atr_error_handling() {
        let mut atr = Atr::new(5).unwrap();

        let too_short: Vec<Candle> = trending_candles().into_iter().take(3).collect();
        let outcome = atr.calculate(&too_short);
        assert!(matches!(outcome, Err(IndicatorError::InsufficientData(_))));

        let no_candles: Vec<Candle> = Vec::new();
        let outcome = atr.calculate(&no_candles);
        assert!(matches!(outcome, Err(IndicatorError::InsufficientData(_))));
    }

    #[test]
    fn test_atr_reset() {
        let mut atr = Atr::new(3).unwrap();
        let mut feed = trending_candles().into_iter();
        for candle in feed.by_ref().take(3) {
            atr.next(candle).unwrap();
        }

        atr.reset();

        // After a reset the warm-up starts over, so one candle yields nothing.
        let following = feed.next().unwrap();
        assert_eq!(atr.next(following).unwrap(), None);
    }
}