use crate::indicator_error::IndicatorError;
use tracing::instrument;
/// Computes the True Range (TR) series for a sequence of candles.
///
/// The first bar has no previous close, so its value is simply
/// `high[0] - low[0]`. Every later bar `i` is measured against the close of
/// bar `i - 1`:
///
/// `max(high[i] - low[i], |high[i] - close[i-1]|, |low[i] - close[i-1]|)`
///
/// On success the returned vector has exactly as many elements as each input
/// slice.
///
/// # Errors
///
/// * [`IndicatorError::EmptyData`] if any of the three slices is empty. This
///   is checked before lengths are compared, so it wins over a length mismatch.
/// * [`IndicatorError::DifferentDataLength`] if the slices are not all the
///   same length.
#[instrument(level = "trace", skip_all, ret)]
pub fn calculate_tr(
    candle_close: &[f64],
    candle_high: &[f64],
    candle_low: &[f64],
) -> Result<Vec<f64>, IndicatorError> {
    let inputs = [candle_close, candle_high, candle_low];
    if inputs.iter().any(|series| series.is_empty()) {
        return Err(IndicatorError::EmptyData);
    }
    let len = candle_close.len();
    if candle_high.len() != len || candle_low.len() != len {
        return Err(IndicatorError::DifferentDataLength);
    }

    // Seed bar: no prior close exists, so TR is just the bar's own range.
    let seed = candle_high[0] - candle_low[0];

    // Pair every bar from index 1 onward with the close of the bar before it.
    // `zip` stops after len - 1 pairs, so the final close is never consumed.
    let later_bars = candle_close
        .iter()
        .zip(candle_high[1..].iter().zip(&candle_low[1..]))
        .map(|(&prev_close, (&high, &low))| {
            let bar_range = high - low;
            let high_vs_prev = (high - prev_close).abs();
            let low_vs_prev = (low - prev_close).abs();
            bar_range.max(high_vs_prev).max(low_vs_prev)
        });

    Ok(std::iter::once(seed).chain(later_bars).collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing computed TR values.
    const EPS: f64 = 1e-10;

    /// Asserts that `actual` has the same length as `expected` and matches it
    /// element-wise within [`EPS`].
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {actual:?} vs {expected:?}"
        );
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < EPS, "index {i}: got {a}, expected {e}");
        }
    }

    #[test]
    fn test_calculate_tr_valid() {
        let close = [10.0, 12.0, 11.0];
        let high = [11.0, 13.0, 12.0];
        let low = [9.0, 11.0, 10.0];
        let result = calculate_tr(&close, &high, &low).unwrap();
        assert_close(&result, &[2.0, 3.0, 2.0]);
    }

    #[test]
    fn test_calculate_tr_all_gaps() {
        let close = [100.0, 110.0, 120.0];
        let high = [105.0, 115.0, 125.0];
        let low = [95.0, 105.0, 115.0];
        let result = calculate_tr(&close, &high, &low).unwrap();
        assert_close(&result, &[10.0, 15.0, 15.0]);
    }

    #[test]
    fn test_calculate_tr_range_dominates() {
        // Wide bar around an unchanged previous close: high - low is the max.
        let close = [10.0, 10.0];
        let high = [10.5, 12.0];
        let low = [9.5, 8.0];
        let result = calculate_tr(&close, &high, &low).unwrap();
        assert_close(&result, &[1.0, 4.0]);
    }

    #[test]
    fn test_calculate_tr_empty() {
        assert!(matches!(
            calculate_tr(&[], &[1.0], &[1.0]).unwrap_err(),
            IndicatorError::EmptyData
        ));
        assert!(matches!(
            calculate_tr(&[1.0], &[], &[1.0]).unwrap_err(),
            IndicatorError::EmptyData
        ));
        assert!(matches!(
            calculate_tr(&[1.0], &[1.0], &[]).unwrap_err(),
            IndicatorError::EmptyData
        ));
        assert!(matches!(
            calculate_tr(&[], &[], &[]).unwrap_err(),
            IndicatorError::EmptyData
        ));
    }

    #[test]
    fn test_calculate_tr_empty_takes_precedence_over_length() {
        assert!(matches!(
            calculate_tr(&[], &[1.0, 2.0], &[1.0]).unwrap_err(),
            IndicatorError::EmptyData
        ));
    }

    #[test]
    fn test_calculate_tr_length_mismatch() {
        assert!(matches!(
            calculate_tr(&[1.0, 2.0], &[1.0], &[1.0]).unwrap_err(),
            IndicatorError::DifferentDataLength
        ));
        // Close and high agree; only low differs.
        assert!(matches!(
            calculate_tr(&[1.0, 2.0], &[1.0, 2.0], &[1.0]).unwrap_err(),
            IndicatorError::DifferentDataLength
        ));
    }

    #[test]
    fn test_calculate_tr_single_element() {
        let result = calculate_tr(&[10.0], &[12.0], &[8.0]).unwrap();
        assert_close(&result, &[4.0]);
    }

    #[test]
    fn test_calculate_tr_abs_used_correctly() {
        // Previous close sits above the whole bar: |low - prev_close| wins.
        let close = [50.0, 40.0];
        let high = [55.0, 42.0];
        let low = [45.0, 38.0];
        let result = calculate_tr(&close, &high, &low).unwrap();
        assert_close(&result, &[10.0, 12.0]);
    }
}