use crate::indicators::volatility::calculate_atr;
use crate::util::dataframe_utils::check_window_size;
use polars::prelude::*;
/// Computes Keltner Channels for `df`.
///
/// The middle band is an exponential moving average (EMA) of `close` with smoothing factor
/// `2 / (window + 1)`. It is seeded with the mean of the non-NaN closes in the first `window`
/// rows. The upper and lower bands are `middle ± multiplier * ATR`, where the ATR comes from
/// [`calculate_atr`] called with the same `window`.
///
/// Null closes are treated as NaN:
/// - The first `window - 1` rows of the middle band are NaN.
/// - A NaN close yields a NaN row, but the EMA then continues from the last valid value.
/// - If every seed value is NaN, the whole middle band is NaN.
/// - A band value is NaN whenever the middle band or the ATR is NaN on that row.
///
/// Returns a DataFrame with columns `keltner_upper`, `keltner_middle`, `keltner_lower`
/// (in that order).
///
/// # Errors
///
/// - Any error returned by [`check_window_size`].
/// - `PolarsError::ShapeMismatch` if `high`, `low` or `close` is missing.
/// - `PolarsError::ComputeError` if `window` is 0.
/// - An error if `close` is not a `Float64` column, or if [`calculate_atr`] fails.
/// - An error from `DataFrame::new` if the ATR series length differs from the input height.
pub fn calculate_keltner_channels(
    df: &DataFrame,
    window: usize,
    multiplier: f64,
) -> PolarsResult<DataFrame> {
    check_window_size(df, window, "Keltner Channels")?;
    if !df.schema().contains("high")
        || !df.schema().contains("low")
        || !df.schema().contains("close")
    {
        return Err(PolarsError::ShapeMismatch(
            "DataFrame must contain 'high', 'low', and 'close' columns for Keltner Channels calculation".into(),
        ));
    }
    // Without this guard `window - 1` would underflow below.
    if window == 0 {
        return Err(PolarsError::ComputeError(
            "window must be at least 1 for Keltner Channels calculation".into(),
        ));
    }

    // Nulls become NaN so the EMA helper can work on a plain slice.
    let closes: Vec<f64> = df
        .column("close")?
        .f64()?
        .into_iter()
        .map(|v| v.unwrap_or(f64::NAN))
        .collect();
    let middle_band = sma_seeded_ema(&closes, window);

    let atr_series = calculate_atr(df, window)?;
    // Resolve the Float64 view once instead of once per row.
    let atr = atr_series.f64()?;
    let (upper_band, lower_band): (Vec<f64>, Vec<f64>) = middle_band
        .iter()
        .zip(atr)
        .map(|(&mid, atr_val)| {
            let atr_val = atr_val.unwrap_or(f64::NAN);
            if mid.is_nan() || atr_val.is_nan() {
                (f64::NAN, f64::NAN)
            } else {
                (mid + multiplier * atr_val, mid - multiplier * atr_val)
            }
        })
        .unzip();

    let middle_series = Series::new("keltner_middle".into(), middle_band);
    let upper_series = Series::new("keltner_upper".into(), upper_band);
    let lower_series = Series::new("keltner_lower".into(), lower_band);
    DataFrame::new(vec![
        upper_series.into(),
        middle_series.into(),
        lower_series.into(),
    ])
}

/// EMA of `values` with smoothing factor `2 / (window + 1)`, seeded with the mean of the
/// non-NaN entries among the first `window` values.
///
/// Always returns exactly `values.len()` entries:
/// - Indices before `window - 1` are NaN.
/// - The seed is stored at `window - 1`.
/// - Each later index holds the EMA, or NaN when that input is NaN; the EMA then resumes
///   from the last valid value.
/// - If `window` is 0 or exceeds `values.len()`, every entry is NaN.
fn sma_seeded_ema(values: &[f64], window: usize) -> Vec<f64> {
    let mut out = vec![f64::NAN; values.len()];
    if window == 0 || values.len() < window {
        return out;
    }

    let (sum, count) = values[..window]
        .iter()
        .filter(|v| !v.is_nan())
        .fold((0.0_f64, 0_usize), |(s, c), &v| (s + v, c + 1));
    let seed = if count > 0 { sum / count as f64 } else { f64::NAN };
    out[window - 1] = seed;

    let alpha = 2.0 / (window as f64 + 1.0);
    let mut prev = seed;
    for (slot, &x) in out[window..].iter_mut().zip(&values[window..]) {
        // A NaN input leaves this slot NaN and keeps `prev`, so the EMA resumes later.
        if !x.is_nan() && !prev.is_nan() {
            let ema = x * alpha + prev * (1.0 - alpha);
            *slot = ema;
            prev = ema;
        }
    }
    out
}