Skip to main content

indicators/volatility/
keltner_channels.rs

//! Keltner Channels.
//!
//! Ported from `keltner_channels.py` :: `class KeltnerChannelsIndicator`.
//!
//! # Algorithm
//!
//! 1. `middle[i] = EMA(close, period)`
//! 2. `true_range[i] = max(H−L, |H−prev_C|, |L−prev_C|)` (just `H−L` for the first bar)
//! 3. `atr[i] = rolling mean of true_range over period bars` (min_periods=1, matching Python)
//! 4. `upper[i] = middle[i] + multiplier × atr[i]`
//! 5. `lower[i] = middle[i] − multiplier × atr[i]`
//!
//! The upper and lower bands are `NaN` wherever the EMA middle band is still `NaN`.
//!
//! Output columns: `"KC_upper"`, `"KC_lower"`, `"KC_middle"`.

15use std::collections::HashMap;
16
17use crate::error::IndicatorError;
18use crate::functions::{self};
19use crate::indicator::{Indicator, IndicatorOutput};
20use crate::registry::{param_f64, param_usize};
21use crate::types::Candle;
22
// ── Params ────────────────────────────────────────────────────────────────────

/// Tunable inputs for the Keltner Channels indicator.
///
/// The [`Default`] values mirror the Python implementation: a 20-bar look-back and
/// bands two ATRs away from the middle line.
#[derive(Debug, Clone)]
pub struct KeltnerParams {
    /// Look-back length shared by the EMA middle band and the ATR rolling window.
    pub period: usize,
    /// Number of ATRs between the middle band and each outer band.
    pub multiplier: f64,
}
32
33impl Default for KeltnerParams {
34    fn default() -> Self {
35        Self {
36            period: 20,
37            multiplier: 2.0,
38        }
39    }
40}
41
// ── Indicator struct ──────────────────────────────────────────────────────────

44#[derive(Debug, Clone)]
45pub struct KeltnerChannels {
46    pub params: KeltnerParams,
47}
48
49impl KeltnerChannels {
50    pub fn new(params: KeltnerParams) -> Self {
51        Self { params }
52    }
53
54    pub fn with_period(period: usize) -> Self {
55        Self::new(KeltnerParams {
56            period,
57            ..Default::default()
58        })
59    }
60}
61
// ── Indicator impl ────────────────────────────────────────────────────────────

64impl Indicator for KeltnerChannels {
65    fn name(&self) -> &'static str {
66        "KeltnerChannels"
67    }
68    fn required_len(&self) -> usize {
69        self.params.period
70    }
71    fn required_columns(&self) -> &[&'static str] {
72        &["high", "low", "close"]
73    }
74
75    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
76        self.check_len(candles)?;
77
78        let n = candles.len();
79        let p = self.params.period;
80        let mult = self.params.multiplier;
81
82        // EMA of close → middle band
83        let close: Vec<f64> = candles.iter().map(|c| c.close).collect();
84        let middle = functions::ema(&close, p)?;
85
86        // True range (max of three measures; H-L only for bar 0)
87        let mut tr = vec![0.0f64; n];
88        for i in 0..n {
89            let hl = candles[i].high - candles[i].low;
90            tr[i] = if i == 0 {
91                hl
92            } else {
93                let pc = candles[i - 1].close;
94                hl.max((candles[i].high - pc).abs())
95                    .max((candles[i].low - pc).abs())
96            };
97        }
98
99        // Rolling mean of TR with min_periods=1 (matches Python's rolling(window, min_periods=1).mean())
100        let mut atr = vec![0.0f64; n];
101        for i in 0..n {
102            let start = (i + 1).saturating_sub(p);
103            atr[i] = tr[start..=i].iter().sum::<f64>() / (i - start + 1) as f64;
104        }
105
106        // Bands — only where middle is non-NaN (needs `period` bars of EMA warm-up)
107        let mut upper = vec![f64::NAN; n];
108        let mut lower = vec![f64::NAN; n];
109        for i in 0..n {
110            if !middle[i].is_nan() {
111                upper[i] = middle[i] + mult * atr[i];
112                lower[i] = middle[i] - mult * atr[i];
113            }
114        }
115
116        Ok(IndicatorOutput::from_pairs([
117            ("KC_upper".to_string(), upper),
118            ("KC_lower".to_string(), lower),
119            ("KC_middle".to_string(), middle),
120        ]))
121    }
122}
123
// ── Registry factory ──────────────────────────────────────────────────────────

126pub fn factory<S: ::std::hash::BuildHasher>(
127    params: &HashMap<String, String, S>,
128) -> Result<Box<dyn Indicator>, IndicatorError> {
129    Ok(Box::new(KeltnerChannels::new(KeltnerParams {
130        period: param_usize(params, "period", 20)?,
131        multiplier: param_f64(params, "multiplier", 2.0)?,
132    })))
133}
134
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Synthetic bars where `H−L = 2 + 0.15·i` always exceeds both close-gap terms
    /// (`|H−prev_C| = 1.1`, `|L−prev_C| = 0.9 + 0.15·i`). So
    /// `true_range[i] == 2 + 0.15·i` up to floating-point rounding.
    fn candles(n: usize) -> Vec<Candle> {
        (0..n)
            .map(|i| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: 10.0 + i as f64 * 0.05,
                high: 11.0 + i as f64 * 0.10,
                low: 9.0 - i as f64 * 0.05,
                close: 10.0 + i as f64 * 0.10,
                volume: 100.0,
            })
            .collect()
    }

    #[test]
    fn kc_three_output_columns() {
        let out = KeltnerChannels::with_period(10)
            .calculate(&candles(15))
            .unwrap();
        assert!(out.get("KC_upper").is_some());
        assert!(out.get("KC_lower").is_some());
        assert!(out.get("KC_middle").is_some());
    }

    #[test]
    fn kc_upper_above_lower() {
        let out = KeltnerChannels::with_period(5)
            .calculate(&candles(20))
            .unwrap();
        let upper = out.get("KC_upper").unwrap();
        let lower = out.get("KC_lower").unwrap();
        for i in 0..20 {
            if !upper[i].is_nan() {
                assert!(upper[i] > lower[i], "upper <= lower at {i}");
            }
        }
    }

    #[test]
    fn kc_middle_is_ema() {
        // Middle band must equal EMA(close, period) exactly.
        let bars = candles(20);
        let closes: Vec<f64> = bars.iter().map(|c| c.close).collect();
        let ema = functions::ema(&closes, 5).unwrap();
        let out = KeltnerChannels::with_period(5).calculate(&bars).unwrap();
        let middle = out.get("KC_middle").unwrap();
        for i in 0..20 {
            if !ema[i].is_nan() {
                assert!((middle[i] - ema[i]).abs() < 1e-9, "middle≠EMA at {i}");
            }
        }
    }

    #[test]
    fn kc_bands_symmetric_around_middle() {
        // Bands must be NaN exactly where the middle is NaN, and centred on it elsewhere.
        let out = KeltnerChannels::with_period(5)
            .calculate(&candles(20))
            .unwrap();
        let upper = out.get("KC_upper").unwrap();
        let lower = out.get("KC_lower").unwrap();
        let middle = out.get("KC_middle").unwrap();
        for i in 0..20 {
            if middle[i].is_nan() {
                assert!(
                    upper[i].is_nan() && lower[i].is_nan(),
                    "band without middle at {i}"
                );
            } else {
                let centre = (upper[i] + lower[i]) / 2.0;
                assert!((centre - middle[i]).abs() < 1e-9, "bands not centred at {i}");
            }
        }
    }

    #[test]
    fn kc_band_width_is_multiplier_times_rolling_true_range() {
        // Pins steps 2–5 of the algorithm. With the fixture's TR = 2 + 0.15·i, the
        // min_periods=1 rolling mean over bars `start..=i` is 2 + 0.15·(start + i)/2.
        // `upper − lower` must equal 2 × multiplier × that ATR.
        let period = 5;
        let multiplier = 1.5;
        let out = KeltnerChannels::new(KeltnerParams { period, multiplier })
            .calculate(&candles(20))
            .unwrap();
        let upper = out.get("KC_upper").unwrap();
        let lower = out.get("KC_lower").unwrap();
        assert!(upper.iter().any(|v| !v.is_nan()), "no band values produced");
        for i in 0..20 {
            if upper[i].is_nan() {
                continue;
            }
            let start = (i + 1).saturating_sub(period);
            let expected_atr = 2.0 + 0.15 * (start + i) as f64 / 2.0;
            let atr = (upper[i] - lower[i]) / (2.0 * multiplier);
            assert!(
                (atr - expected_atr).abs() < 1e-9,
                "ATR mismatch at {i}: got {atr}, expected {expected_atr}"
            );
        }
    }

    #[test]
    fn kc_insufficient_data_errors() {
        assert!(
            KeltnerChannels::with_period(10)
                .calculate(&candles(5))
                .is_err()
        );
    }

    #[test]
    fn factory_creates_keltner() {
        assert_eq!(factory(&HashMap::new()).unwrap().name(), "KeltnerChannels");
    }
}