use std::collections::VecDeque;
use crate::indicators::trend::{Ema, Sma};
use crate::indicators::volatility::Atr;
use crate::indicators::{Candle, Indicator, IndicatorError};
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KeltnerChannelsResult {
pub middle: f64,
pub upper: f64,
pub lower: f64,
pub bandwidth: f64,
}
#[derive(Debug)]
pub struct KeltnerChannels {
ema_period: usize,
atr_period: usize,
multiplier: f64,
candle_buffer: VecDeque<Candle>,
current_ema: Option<f64>,
current_atr: Option<f64>,
}
impl KeltnerChannels {
pub fn new(
ema_period: usize,
atr_period: usize,
multiplier: f64,
) -> Result<Self, IndicatorError> {
if ema_period < 1 {
return Err(IndicatorError::InvalidParameter(
"EMA period must be at least 1".to_string(),
));
}
if atr_period < 1 {
return Err(IndicatorError::InvalidParameter(
"ATR period must be at least 1".to_string(),
));
}
if multiplier <= 0.0 {
return Err(IndicatorError::InvalidParameter(
"ATR multiplier must be positive".to_string(),
));
}
Ok(Self {
ema_period,
atr_period,
multiplier,
candle_buffer: VecDeque::with_capacity(ema_period.max(atr_period)),
current_ema: None,
current_atr: None,
})
}
}
impl Indicator<Candle, KeltnerChannelsResult> for KeltnerChannels {
fn calculate(&mut self, data: &[Candle]) -> Result<Vec<KeltnerChannelsResult>, IndicatorError> {
let min_data_len = self.ema_period.max(self.atr_period);
if data.len() < min_data_len {
return Err(IndicatorError::InsufficientData(format!(
"Keltner Channels needs at least {} data points",
min_data_len
)));
}
let n = data.len();
let mut result = Vec::with_capacity(n - min_data_len + 1);
self.reset();
let mut ema = Ema::new(self.ema_period)?;
let close_prices: Vec<f64> = data.iter().map(|c| c.close).collect();
let ema_values = ema.calculate(&close_prices)?;
let mut atr = Atr::new(self.atr_period)?;
let atr_values = atr.calculate(data)?;
let ema_offset = self.atr_period.saturating_sub(self.ema_period);
let atr_offset = self.ema_period.saturating_sub(self.atr_period);
for i in 0..atr_values.len().min(ema_values.len() - ema_offset) {
let ema = ema_values[i + ema_offset];
let atr = atr_values[i + atr_offset];
let upper = ema + (self.multiplier * atr);
let lower = ema - (self.multiplier * atr);
let bandwidth = (upper - lower) / ema;
result.push(KeltnerChannelsResult {
middle: ema,
upper,
lower,
bandwidth,
});
}
self.current_ema = Some(*ema_values.last().unwrap());
self.current_atr = Some(*atr_values.last().unwrap());
for candle in data.iter().take(n).skip(n - min_data_len) {
self.candle_buffer.push_back(*candle);
}
Ok(result)
}
fn next(&mut self, value: Candle) -> Result<Option<KeltnerChannelsResult>, IndicatorError> {
self.candle_buffer.push_back(value);
let min_data_len = self.ema_period.max(self.atr_period);
if self.candle_buffer.len() > min_data_len {
self.candle_buffer.pop_front();
}
if self.candle_buffer.len() < min_data_len {
return Ok(None);
}
if let Some(current_ema) = self.current_ema {
let alpha = 2.0 / (self.ema_period as f64 + 1.0);
let new_ema = (value.close - current_ema) * alpha + current_ema;
self.current_ema = Some(new_ema);
} else {
let mut ema = Ema::new(self.ema_period)?;
let close_prices: Vec<f64> = self.candle_buffer.iter().map(|c| c.close).collect();
let ema_values = ema.calculate(&close_prices)?;
self.current_ema = Some(*ema_values.last().unwrap());
}
let mut atr = Atr::new(self.atr_period)?;
let candles: Vec<Candle> = self.candle_buffer.iter().cloned().collect();
let atr_values = atr.calculate(&candles)?;
self.current_atr = Some(*atr_values.last().unwrap());
let ema = self.current_ema.unwrap();
let atr = self.current_atr.unwrap();
let upper = ema + (self.multiplier * atr);
let lower = ema - (self.multiplier * atr);
let bandwidth = (upper - lower) / ema;
Ok(Some(KeltnerChannelsResult {
middle: ema,
upper,
lower,
bandwidth,
}))
}
fn reset(&mut self) {
self.candle_buffer.clear();
self.current_ema = None;
self.current_atr = None;
}
}
#[derive(Debug, Clone)]
pub struct KeltnerChannelsPrice {
ema_period: usize,
atr_period: usize,
price_buffer: VecDeque<f64>,
atr_buffer: VecDeque<f64>,
current_ema: Option<f64>,
current_atr: Option<f64>,
}
impl KeltnerChannelsPrice {
pub fn new(ema_period: usize, atr_period: usize) -> Result<Self, IndicatorError> {
if ema_period < 1 {
return Err(IndicatorError::InvalidParameter(
"EMA period must be at least 1".to_string(),
));
}
if atr_period < 1 {
return Err(IndicatorError::InvalidParameter(
"ATR period must be at least 1".to_string(),
));
}
Ok(Self {
ema_period,
atr_period,
price_buffer: VecDeque::with_capacity(ema_period),
atr_buffer: VecDeque::with_capacity(atr_period),
current_ema: None,
current_atr: None,
})
}
fn calculate_volatility(&self, prices: &[f64]) -> Result<Vec<f64>, IndicatorError> {
if prices.len() < 2 {
return Err(IndicatorError::InsufficientData(
"Need at least 2 prices to calculate volatility".to_string(),
));
}
let mut volatility = Vec::with_capacity(prices.len() - 1);
for i in 1..prices.len() {
let price_change = (prices[i] - prices[i - 1]).abs();
volatility.push(price_change);
}
let all_zeros = volatility.iter().all(|&x| x == 0.0);
if all_zeros {
volatility = volatility.iter().map(|_| 0.001).collect();
}
let mut sma = Sma::new(self.atr_period)?;
let vol_sma = sma.calculate(&volatility)?;
Ok(vol_sma)
}
}
impl Indicator<f64, f64> for KeltnerChannelsPrice {
fn calculate(&mut self, data: &[f64]) -> Result<Vec<f64>, IndicatorError> {
let min_data_len = self.ema_period + self.atr_period;
if data.len() < min_data_len {
return Err(IndicatorError::InsufficientData(format!(
"Keltner Channels needs at least {} data points for price-only mode",
min_data_len
)));
}
self.reset();
let mut ema = Ema::new(self.ema_period)?;
let ema_values = ema.calculate(data)?;
let volatility = self.calculate_volatility(data)?;
let first_valid_idx = (self.ema_period - 1).max(self.atr_period);
let result_len = data.len().saturating_sub(first_valid_idx);
let mut result = Vec::with_capacity(result_len);
for i in 0..result_len {
let ema_idx = i + first_valid_idx - (self.ema_period - 1);
if ema_idx < ema_values.len() {
result.push(ema_values[ema_idx]);
}
}
self.current_ema = ema_values.last().cloned();
self.current_atr = volatility.last().cloned();
for price in data.iter().rev().take(self.ema_period) {
self.price_buffer.push_front(*price);
}
for vol in volatility.iter().rev().take(self.atr_period) {
self.atr_buffer.push_front(*vol);
}
Ok(result)
}
fn next(&mut self, value: f64) -> Result<Option<f64>, IndicatorError> {
self.price_buffer.push_back(value);
if self.price_buffer.len() > self.ema_period {
self.price_buffer.pop_front();
}
if self.price_buffer.len() < 2 {
return Ok(None);
}
let prev_price = self.price_buffer[self.price_buffer.len() - 2];
let new_volatility = (value - prev_price).abs();
self.atr_buffer.push_back(new_volatility);
if self.atr_buffer.len() > self.atr_period {
self.atr_buffer.pop_front();
}
if let Some(current_ema) = self.current_ema {
let alpha = 2.0 / (self.ema_period as f64 + 1.0);
let new_ema = value * alpha + current_ema * (1.0 - alpha);
self.current_ema = Some(new_ema);
} else if self.price_buffer.len() >= self.ema_period {
let mut ema = Ema::new(self.ema_period)?;
let prices: Vec<f64> = self.price_buffer.iter().cloned().collect();
let ema_values = ema.calculate(&prices)?;
self.current_ema = Some(*ema_values.last().unwrap());
} else {
return Ok(None);
}
if self.atr_buffer.len() >= self.atr_period {
let sum: f64 = self.atr_buffer.iter().sum();
let new_atr = sum / self.atr_period as f64;
self.current_atr = Some(new_atr);
} else {
return Ok(None);
}
Ok(self.current_ema)
}
fn reset(&mut self) {
self.price_buffer.clear();
self.atr_buffer.clear();
self.current_ema = None;
self.current_atr = None;
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::indicators::Candle;
#[test]
fn test_keltner_channels_new() {
assert!(KeltnerChannels::new(20, 10, 2.0).is_ok());
assert!(KeltnerChannels::new(0, 10, 2.0).is_err());
assert!(KeltnerChannels::new(20, 10, -1.0).is_err());
}
#[test]
fn test_keltner_channels_calculation() {
let mut kc = KeltnerChannels::new(3, 3, 2.0).unwrap();
let candles = vec![
Candle {
timestamp: 1,
open: 10.0,
high: 13.0,
low: 7.0,
close: 10.0,
volume: 1000.0,
},
Candle {
timestamp: 2,
open: 10.0,
high: 13.0,
low: 7.0,
close: 10.0,
volume: 1000.0,
},
Candle {
timestamp: 3,
open: 10.0,
high: 13.0,
low: 7.0,
close: 10.0,
volume: 1000.0,
},
Candle {
timestamp: 4,
open: 10.0,
high: 13.0,
low: 7.0,
close: 10.0,
volume: 1000.0,
},
Candle {
timestamp: 5,
open: 10.0,
high: 13.0,
low: 7.0,
close: 10.0,
volume: 1000.0,
},
];
let result = kc.calculate(&candles).unwrap();
assert_eq!(result.len(), 3);
assert_eq!(result[0].middle, 10.0);
assert!((result[0].upper - 22.0).abs() < 0.1);
assert!((result[0].lower - (-2.0)).abs() < 0.1);
}
#[test]
fn test_keltner_channels_reset() {
let mut kc = KeltnerChannels::new(3, 3, 2.0).unwrap();
let candle1 = Candle {
timestamp: 1,
open: 10.0,
high: 12.0,
low: 9.0,
close: 11.0,
volume: 1000.0,
};
kc.next(candle1).unwrap();
kc.reset();
assert!(kc.current_ema.is_none());
assert!(kc.current_atr.is_none());
}
#[test]
fn test_keltner_channels_price_new() {
assert!(KeltnerChannelsPrice::new(20, 10).is_ok());
assert!(KeltnerChannelsPrice::new(0, 10).is_err());
}
#[test]
fn test_keltner_channels_price_calculation() {
let mut kc = KeltnerChannelsPrice::new(3, 2).unwrap();
let prices = vec![10.0, 10.0, 10.0, 10.0, 10.0, 10.0];
let result = kc.calculate(&prices).unwrap();
println!("Result length: {}", result.len());
assert!(!result.is_empty());
for val in result.iter() {
assert!((val - 10.0).abs() < 0.001);
}
}
#[test]
fn test_keltner_channels_price_next() {
let mut kc = KeltnerChannelsPrice::new(3, 2).unwrap();
let _ = kc.next(10.0); let _ = kc.next(10.0); let _ = kc.next(10.0); let _ = kc.next(10.0);
let result = kc.next(10.0).unwrap();
assert!(result.is_some());
if let Some(value) = result {
assert!((value - 10.0).abs() < 0.001);
}
}
#[test]
fn test_keltner_channels_price_reset() {
let mut kc = KeltnerChannelsPrice::new(3, 2).unwrap();
let _ = kc.next(10.0);
let _ = kc.next(10.0);
let _ = kc.next(10.0);
let _ = kc.next(10.0);
kc.reset();
assert!(kc.current_ema.is_none());
assert!(kc.current_atr.is_none());
assert_eq!(kc.price_buffer.len(), 0);
assert_eq!(kc.atr_buffer.len(), 0);
}
}