use super::{IndicatorError, Result, atr::atr_raw, ema::ema_raw};
use serde::{Deserialize, Serialize};
/// Output of [`keltner_channels`]: three series, each holding one entry per
/// input bar (index `i` refers to the same bar in every series).
///
/// An entry is `None` for bars where the value cannot be computed yet because
/// too few preceding bars exist for the EMA and/or ATR window.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KeltnerChannelsResult {
/// Middle line plus `multiplier` times the ATR value for that bar.
pub upper: Vec<Option<f64>>,
/// Middle line: the values produced by `ema_raw` over the closing prices.
pub middle: Vec<Option<f64>>,
/// Middle line minus `multiplier` times the ATR value for that bar.
pub lower: Vec<Option<f64>>,
}
/// Computes Keltner Channels over OHLC-style bar data.
///
/// The middle line is an EMA of `closes` over `period` bars. The upper and
/// lower bands are offset from it by `multiplier` times the ATR computed over
/// `atr_period` bars from `highs`, `lows` and `closes`.
///
/// Every output vector has the same length as the inputs. The middle line is
/// `None` before index `period - 1`; the bands are additionally `None` before
/// index `atr_period - 1`.
///
/// # Errors
///
/// - [`IndicatorError::InvalidPeriod`] if `period` or `atr_period` is zero, or
///   if `highs`, `lows` and `closes` do not all have the same length (this
///   variant is reused for the length mismatch).
/// - [`IndicatorError::InsufficientData`] if fewer than `period` bars are given.
/// - Any error propagated from the ATR computation.
pub fn keltner_channels(
    highs: &[f64],
    lows: &[f64],
    closes: &[f64],
    period: usize,
    atr_period: usize,
    multiplier: f64,
) -> Result<KeltnerChannelsResult> {
    // Both windows must be non-empty; the smaller one being zero covers both.
    if period.min(atr_period) == 0 {
        return Err(IndicatorError::InvalidPeriod(
            "Periods must be greater than 0".to_string(),
        ));
    }

    let bars = highs.len();
    let same_length = lows.len() == bars && closes.len() == bars;
    if !same_length {
        return Err(IndicatorError::InvalidPeriod(
            "Data lengths must match".to_string(),
        ));
    }

    if bars < period {
        return Err(IndicatorError::InsufficientData {
            need: period,
            got: bars,
        });
    }

    // The helper maps element 0 of this series to bar `atr_period - 1`;
    // presumably `atr_raw` yields exactly that alignment — confirm in atr.rs.
    let atr = atr_raw(highs, lows, closes, atr_period)?;
    keltner_with_atr_dense(closes, period, &atr, atr_period, multiplier)
}
/// Builds Keltner Channels from closing prices and a precomputed ATR series.
///
/// `atr_dense` is expected to be "dense": element `j` holds the ATR for bar
/// `j + atr_period - 1`, with no leading placeholders. The middle line comes
/// from `ema_raw(closes, period)`, whose element `k` is placed at bar
/// `k + period - 1`.
///
/// Bands are only filled where both an EMA value and an ATR value exist. If
/// `atr_dense` is shorter than the alignment requires, the affected bars keep
/// `None` bands instead of causing an out-of-bounds panic.
///
/// # Errors
///
/// - [`IndicatorError::InvalidPeriod`] if `period` or `atr_period` is zero.
/// - [`IndicatorError::InsufficientData`] if `closes` has fewer than `period`
///   elements.
pub(crate) fn keltner_with_atr_dense(
    closes: &[f64],
    period: usize,
    atr_dense: &[f64],
    atr_period: usize,
    multiplier: f64,
) -> Result<KeltnerChannelsResult> {
    if period == 0 || atr_period == 0 {
        return Err(IndicatorError::InvalidPeriod(
            "Periods must be greater than 0".to_string(),
        ));
    }
    let len = closes.len();
    if len < period {
        return Err(IndicatorError::InsufficientData {
            need: period,
            got: len,
        });
    }

    let ema_vals = ema_raw(closes, period);
    // Bar index of the first EMA value and of the first ATR value.
    let ema_off = period - 1;
    let atr_off = atr_period - 1;

    let mut upper = vec![None; len];
    let mut middle = vec![None; len];
    let mut lower = vec![None; len];

    for (k, &ev) in ema_vals.iter().enumerate() {
        let i = k + ema_off;
        middle[i] = Some(ev);
        // `checked_sub` skips bars before the ATR window is full; `.get` keeps
        // the bands `None` rather than panicking if `atr_dense` is too short.
        if let Some(&av) = i.checked_sub(atr_off).and_then(|j| atr_dense.get(j)) {
            upper[i] = Some(ev + multiplier * av);
            lower[i] = Some(ev - multiplier * av);
        }
    }

    Ok(KeltnerChannelsResult {
        upper,
        middle,
        lower,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `value` is `Some(x)` with `x` within 1e-9 of `expected`.
    fn assert_close(value: Option<f64>, expected: f64) {
        let actual = value.expect("expected a value, got None");
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_keltner_channels() {
        let highs = vec![10.0; 20];
        let lows = vec![8.0; 20];
        let closes = vec![9.0; 20];
        let result = keltner_channels(&highs, &lows, &closes, 10, 10, 2.0).unwrap();

        assert_eq!(result.upper.len(), 20);
        assert_eq!(result.middle.len(), 20);
        assert_eq!(result.lower.len(), 20);

        assert!(result.upper[8].is_none());
        assert!(result.middle[8].is_none());
        assert!(result.lower[8].is_none());
        assert!(result.upper[9].is_some());

        // Flat bars: every true range is high - low = 2, so the ATR is 2, and
        // the EMA of a constant 9 stays 9. Bands therefore sit at 9 +/- 2 * 2.
        for i in 9..20 {
            assert_close(result.middle[i], 9.0);
            assert_close(result.upper[i], 13.0);
            assert_close(result.lower[i], 5.0);
        }
    }

    #[test]
    fn bands_wait_for_longer_atr_period() {
        let highs = vec![10.0; 20];
        let lows = vec![8.0; 20];
        let closes = vec![9.0; 20];
        let result = keltner_channels(&highs, &lows, &closes, 3, 10, 2.0).unwrap();

        // The middle line starts at `period - 1`...
        assert!(result.middle[1].is_none());
        assert!(result.middle[2].is_some());
        // ...but the bands need the ATR window, which starts at `atr_period - 1`.
        assert!(result.upper[8].is_none());
        assert!(result.lower[8].is_none());
        assert!(result.upper[9].is_some());
        assert!(result.lower[9].is_some());
    }

    #[test]
    fn rejects_invalid_input() {
        let highs = vec![10.0; 20];
        let lows = vec![8.0; 20];
        let closes = vec![9.0; 20];

        assert!(matches!(
            keltner_channels(&highs, &lows, &closes, 0, 10, 2.0),
            Err(IndicatorError::InvalidPeriod(_))
        ));
        assert!(matches!(
            keltner_channels(&highs, &lows, &closes, 10, 0, 2.0),
            Err(IndicatorError::InvalidPeriod(_))
        ));
        assert!(matches!(
            keltner_channels(&highs, &lows[..19], &closes, 10, 10, 2.0),
            Err(IndicatorError::InvalidPeriod(_))
        ));
        assert!(matches!(
            keltner_channels(&highs, &lows, &closes, 30, 10, 2.0),
            Err(IndicatorError::InsufficientData { need: 30, got: 20 })
        ));
    }
}