indicators/volatility/
bollinger.rs1use std::collections::HashMap;
18
19use crate::error::IndicatorError;
20use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
21use crate::registry::{param_f64, param_str, param_usize};
22use crate::types::Candle;
23
24#[derive(Debug, Clone)]
27pub struct BollingerParams {
28 pub period: usize,
30 pub std_dev: f64,
32 pub column: PriceColumn,
34}
35
36impl Default for BollingerParams {
37 fn default() -> Self {
38 Self {
39 period: 20,
40 std_dev: 2.0,
41 column: PriceColumn::Close,
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
49pub struct BollingerBands {
50 pub params: BollingerParams,
51}
52
53impl BollingerBands {
54 pub fn new(params: BollingerParams) -> Self {
55 Self { params }
56 }
57
58 pub fn with_period(period: usize) -> Self {
59 Self::new(BollingerParams {
60 period,
61 ..Default::default()
62 })
63 }
64}
65
/// Rolling population standard deviation over trailing windows of `period` prices.
///
/// Entry `i` is the standard deviation of `prices[i + 1 - period..=i]`, using the
/// population divisor `period` (not `period - 1`), which is the definition used
/// for Bollinger Bands. Entries before the first full window are NaN.
///
/// Returns an all-NaN vector (same length as `prices`) when `period` is 0 or
/// larger than `prices.len()`, instead of panicking.
fn rolling_std(prices: &[f64], period: usize) -> Vec<f64> {
    let mut out = vec![f64::NAN; prices.len()];
    // `slice::windows` panics on a zero size, and `out[period - 1..]` would be
    // out of range when the window is longer than the input.
    if period == 0 || period > prices.len() {
        return out;
    }
    let divisor = period as f64;
    // The k-th window ends at index k + period - 1, which is exactly the k-th
    // slot of `out[period - 1..]`.
    for (slot, window) in out[period - 1..].iter_mut().zip(prices.windows(period)) {
        let mean = window.iter().sum::<f64>() / divisor;
        let variance = window.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / divisor;
        *slot = variance.sqrt();
    }
    out
}
81
82impl Indicator for BollingerBands {
85 fn name(&self) -> &'static str {
86 "BollingerBands"
87 }
88 fn required_len(&self) -> usize {
89 self.params.period
90 }
91 fn required_columns(&self) -> &[&'static str] {
92 &["close"]
93 }
94
95 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
96 self.check_len(candles)?;
97
98 let prices = self.params.column.extract(candles);
99 let num_bars = prices.len();
100 let period = self.params.period;
101 let std_mult = self.params.std_dev;
102
103 let mut middle = vec![f64::NAN; num_bars];
105 for i in (period - 1)..num_bars {
106 middle[i] = prices[(i + 1 - period)..=i].iter().sum::<f64>() / period as f64;
107 }
108
109 let std = rolling_std(&prices, period);
110
111 let mut upper = vec![f64::NAN; num_bars];
112 let mut lower = vec![f64::NAN; num_bars];
113 let mut bandwidth = vec![f64::NAN; num_bars];
114 let mut pct_b = vec![f64::NAN; num_bars];
115
116 for i in (period - 1)..num_bars {
117 let upper_val = middle[i] + std_mult * std[i];
118 let lower_val = middle[i] - std_mult * std[i];
119 upper[i] = upper_val;
120 lower[i] = lower_val;
121 bandwidth[i] = if middle[i] == 0.0 {
122 f64::NAN
123 } else {
124 (upper_val - lower_val) / middle[i]
125 };
126 let band_range = upper_val - lower_val;
127 pct_b[i] = if band_range == 0.0 {
128 f64::NAN
129 } else {
130 (prices[i] - lower_val) / band_range
131 };
132 }
133
134 Ok(IndicatorOutput::from_pairs([
135 ("BB_middle".to_string(), middle),
136 ("BB_upper".to_string(), upper),
137 ("BB_lower".to_string(), lower),
138 ("BB_bandwidth".to_string(), bandwidth),
139 ("BB_pct_b".to_string(), pct_b),
140 ]))
141 }
142}
143
144pub fn factory<S: ::std::hash::BuildHasher>(
147 params: &HashMap<String, String, S>,
148) -> Result<Box<dyn Indicator>, IndicatorError> {
149 let period = param_usize(params, "period", 20)?;
150 let std_dev = param_f64(params, "std_dev", 2.0)?;
151 let column = match param_str(params, "column", "close") {
152 "open" => PriceColumn::Open,
153 "high" => PriceColumn::High,
154 "low" => PriceColumn::Low,
155 _ => PriceColumn::Close,
156 };
157 Ok(Box::new(BollingerBands::new(BollingerParams {
158 period,
159 std_dev,
160 column,
161 })))
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles whose open and close equal each input price, with
    /// high/low one unit either side and a constant volume.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c + 1.0,
                low: c - 1.0,
                close: c,
                volume: 100.0,
            })
            .collect()
    }

    #[test]
    fn bb_five_output_columns() {
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        for key in ["BB_middle", "BB_upper", "BB_lower", "BB_bandwidth", "BB_pct_b"] {
            let col = out
                .get(key)
                .unwrap_or_else(|| panic!("missing column {key}"));
            assert_eq!(col.len(), closes.len(), "wrong length for {key}");
        }
    }

    #[test]
    fn bb_upper_always_above_lower() {
        let closes: Vec<f64> = (1_i32..=20).map(f64::from).collect();
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let upper = out.get("BB_upper").unwrap();
        let lower = out.get("BB_lower").unwrap();
        for i in 4..20 {
            assert!(upper[i] >= lower[i], "upper < lower at {i}");
        }
    }

    #[test]
    fn bb_bands_symmetric_around_middle() {
        let closes = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
        let out = BollingerBands::with_period(4)
            .calculate(&candles(&closes))
            .unwrap();
        let mid = out.get("BB_middle").unwrap();
        let upper = out.get("BB_upper").unwrap();
        let lower = out.get("BB_lower").unwrap();
        let bw = out.get("BB_bandwidth").unwrap();
        for i in 3..closes.len() {
            let above = upper[i] - mid[i];
            let below = mid[i] - lower[i];
            assert!(above >= 0.0, "upper below middle at {i}");
            assert!((above - below).abs() < 1e-9, "asymmetric bands at {i}");
            let expected_bw = (upper[i] - lower[i]) / mid[i];
            assert!((bw[i] - expected_bw).abs() < 1e-12, "bandwidth mismatch at {i}");
        }
    }

    #[test]
    fn bb_correct_warm_up() {
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let mid = out.get("BB_middle").unwrap();
        for (i, &v) in mid.iter().enumerate().take(4) {
            assert!(v.is_nan(), "expected NaN at {i}");
        }
        assert!(!mid[4].is_nan());
    }

    #[test]
    fn bb_constant_prices_bandwidth_zero() {
        let closes = [10.0f64; 10];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let bw = out.get("BB_bandwidth").unwrap();
        let pct_b = out.get("BB_pct_b").unwrap();
        // Zero deviation collapses both bands onto the middle: bandwidth is
        // exactly zero and %B is undefined.
        assert!(bw[9].abs() < 1e-12, "bandwidth {}", bw[9]);
        assert!(pct_b[9].is_nan(), "pct_b {}", pct_b[9]);
    }

    #[test]
    fn bb_middle_equals_sma() {
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let mid = out.get("BB_middle").unwrap();
        assert!((mid[4] - 3.0).abs() < 1e-9, "SMA mismatch: {}", mid[4]);
    }

    #[test]
    fn factory_creates_bollinger() {
        assert_eq!(factory(&HashMap::new()).unwrap().name(), "BollingerBands");
    }
}