indicators/volatility/
keltner_channels.rs1use std::collections::HashMap;
16
17use crate::error::IndicatorError;
18use crate::functions::{self};
19use crate::indicator::{Indicator, IndicatorOutput};
20use crate::registry::{param_f64, param_usize};
21use crate::types::Candle;
22
/// Tuning parameters for [`KeltnerChannels`].
///
/// Compared by value (`PartialEq`) so parameter sets can be checked in tests
/// and by callers without field-by-field comparison.
#[derive(Debug, Clone, PartialEq)]
pub struct KeltnerParams {
    /// Look-back length used both for the EMA centre line and for the
    /// trailing average of the true range.
    pub period: usize,
    /// Factor applied to the average true range to place the upper and lower
    /// bands either side of the centre line.
    pub multiplier: f64,
}
32
33impl Default for KeltnerParams {
34 fn default() -> Self {
35 Self {
36 period: 20,
37 multiplier: 2.0,
38 }
39 }
40}
41
42#[derive(Debug, Clone)]
45pub struct KeltnerChannels {
46 pub params: KeltnerParams,
47}
48
49impl KeltnerChannels {
50 pub fn new(params: KeltnerParams) -> Self {
51 Self { params }
52 }
53
54 pub fn with_period(period: usize) -> Self {
55 Self::new(KeltnerParams {
56 period,
57 ..Default::default()
58 })
59 }
60}
61
62impl Indicator for KeltnerChannels {
65 fn name(&self) -> &'static str {
66 "KeltnerChannels"
67 }
68 fn required_len(&self) -> usize {
69 self.params.period
70 }
71 fn required_columns(&self) -> &[&'static str] {
72 &["high", "low", "close"]
73 }
74
75 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
76 self.check_len(candles)?;
77
78 let n = candles.len();
79 let p = self.params.period;
80 let mult = self.params.multiplier;
81
82 let close: Vec<f64> = candles.iter().map(|c| c.close).collect();
84 let middle = functions::ema(&close, p)?;
85
86 let mut tr = vec![0.0f64; n];
88 for i in 0..n {
89 let hl = candles[i].high - candles[i].low;
90 tr[i] = if i == 0 {
91 hl
92 } else {
93 let pc = candles[i - 1].close;
94 hl.max((candles[i].high - pc).abs())
95 .max((candles[i].low - pc).abs())
96 };
97 }
98
99 let mut atr = vec![0.0f64; n];
101 for i in 0..n {
102 let start = (i + 1).saturating_sub(p);
103 atr[i] = tr[start..=i].iter().sum::<f64>() / (i - start + 1) as f64;
104 }
105
106 let mut upper = vec![f64::NAN; n];
108 let mut lower = vec![f64::NAN; n];
109 for i in 0..n {
110 if !middle[i].is_nan() {
111 upper[i] = middle[i] + mult * atr[i];
112 lower[i] = middle[i] - mult * atr[i];
113 }
114 }
115
116 Ok(IndicatorOutput::from_pairs([
117 ("KC_upper".to_string(), upper),
118 ("KC_lower".to_string(), lower),
119 ("KC_middle".to_string(), middle),
120 ]))
121 }
122}
123
124pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
127 Ok(Box::new(KeltnerChannels::new(KeltnerParams {
128 period: param_usize(params, "period", 20)?,
129 multiplier: param_f64(params, "multiplier", 2.0)?,
130 })))
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic bars: highs and closes rise by 0.10 per bar and lows fall
    /// by 0.05, so every bar's high-low span is strictly positive.
    fn candles(n: usize) -> Vec<Candle> {
        (0..n)
            .map(|i| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: 10.0 + i as f64 * 0.05,
                high: 11.0 + i as f64 * 0.10,
                low: 9.0 - i as f64 * 0.05,
                close: 10.0 + i as f64 * 0.10,
                volume: 100.0,
            })
            .collect()
    }

    #[test]
    fn kc_three_output_columns() {
        let out = KeltnerChannels::with_period(10)
            .calculate(&candles(15))
            .unwrap();
        assert!(out.get("KC_upper").is_some());
        assert!(out.get("KC_lower").is_some());
        assert!(out.get("KC_middle").is_some());
    }

    #[test]
    fn kc_upper_above_lower() {
        let out = KeltnerChannels::with_period(5)
            .calculate(&candles(20))
            .unwrap();
        let upper = out.get("KC_upper").unwrap();
        let lower = out.get("KC_lower").unwrap();
        for i in 0..20 {
            if !upper[i].is_nan() {
                assert!(upper[i] > lower[i], "upper <= lower at {i}");
            }
        }
    }

    #[test]
    fn kc_middle_is_ema() {
        let bars = candles(20);
        let closes: Vec<f64> = bars.iter().map(|c| c.close).collect();
        let ema = functions::ema(&closes, 5).unwrap();
        let out = KeltnerChannels::with_period(5).calculate(&bars).unwrap();
        let middle = out.get("KC_middle").unwrap();
        for i in 0..20 {
            if !ema[i].is_nan() {
                assert!((middle[i] - ema[i]).abs() < 1e-9, "middle≠EMA at {i}");
            }
        }
    }

    #[test]
    fn kc_bands_symmetric_about_middle() {
        let out = KeltnerChannels::with_period(5)
            .calculate(&candles(20))
            .unwrap();
        let upper = out.get("KC_upper").unwrap();
        let lower = out.get("KC_lower").unwrap();
        let middle = out.get("KC_middle").unwrap();
        for i in 0..20 {
            if !middle[i].is_nan() {
                let above = upper[i] - middle[i];
                let below = middle[i] - lower[i];
                assert!((above - below).abs() < 1e-9, "asymmetric bands at {i}");
            }
        }
    }

    #[test]
    fn kc_band_width_scales_with_multiplier() {
        let bars = candles(20);
        let narrow = KeltnerChannels::new(KeltnerParams { period: 5, multiplier: 2.0 })
            .calculate(&bars)
            .unwrap();
        let wide = KeltnerChannels::new(KeltnerParams { period: 5, multiplier: 3.0 })
            .calculate(&bars)
            .unwrap();
        let (nu, nl) = (narrow.get("KC_upper").unwrap(), narrow.get("KC_lower").unwrap());
        let (wu, wl) = (wide.get("KC_upper").unwrap(), wide.get("KC_lower").unwrap());
        for i in 0..20 {
            if !nu[i].is_nan() {
                let narrow_width = nu[i] - nl[i];
                let wide_width = wu[i] - wl[i];
                assert!(
                    (wide_width - 1.5 * narrow_width).abs() < 1e-9,
                    "band width not proportional to multiplier at {i}"
                );
            }
        }
    }

    #[test]
    fn kc_insufficient_data_errors() {
        assert!(
            KeltnerChannels::with_period(10)
                .calculate(&candles(5))
                .is_err()
        );
    }

    #[test]
    fn factory_creates_keltner() {
        assert_eq!(factory(&HashMap::new()).unwrap().name(), "KeltnerChannels");
    }
}