indicators/volatility/
bollinger.rs1use std::collections::HashMap;
18
19use crate::error::IndicatorError;
20use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
21use crate::registry::{param_f64, param_str, param_usize};
22use crate::types::Candle;
23
24#[derive(Debug, Clone)]
27pub struct BollingerParams {
28 pub period: usize,
30 pub std_dev: f64,
32 pub column: PriceColumn,
34}
35
36impl Default for BollingerParams {
37 fn default() -> Self {
38 Self {
39 period: 20,
40 std_dev: 2.0,
41 column: PriceColumn::Close,
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
49pub struct BollingerBands {
50 pub params: BollingerParams,
51}
52
53impl BollingerBands {
54 pub fn new(params: BollingerParams) -> Self {
55 Self { params }
56 }
57
58 pub fn with_period(period: usize) -> Self {
59 Self::new(BollingerParams {
60 period,
61 ..Default::default()
62 })
63 }
64}
65
/// Rolling population standard deviation over a trailing window of `period` samples.
///
/// Bollinger Bands are defined with the *population* standard deviation
/// (divisor `n`, not `n - 1`), which is what this computes. A side effect of
/// the population form is that a one-sample window has a well-defined spread
/// of `0.0` instead of the `0 / 0 = NaN` the sample form would produce.
///
/// Returns a vector the same length as `prices`. Entry `i` holds the standard
/// deviation of `prices[i + 1 - period..=i]`. The first `period - 1` entries
/// (the warm-up region) are `NaN`. If `period` is zero, or longer than
/// `prices`, every entry is `NaN`.
fn rolling_std(prices: &[f64], period: usize) -> Vec<f64> {
    let mut out = vec![f64::NAN; prices.len()];
    // A zero-length window has no standard deviation. Also, `windows(0)` panics
    // and `period - 1` would underflow.
    if period == 0 {
        return out;
    }
    let len = period as f64;
    // `windows(period)` yields exactly `prices.len() - period + 1` windows, one
    // per output slot after the warm-up region. When `period > prices.len()`,
    // both sides of the zip are empty and `out` stays all-NaN.
    for (slot, window) in out.iter_mut().skip(period - 1).zip(prices.windows(period)) {
        let mean = window.iter().sum::<f64>() / len;
        // Two-pass variance: summing squared deviations keeps the result
        // non-negative, so `sqrt` never sees a negative argument.
        let var = window.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / len;
        *slot = var.sqrt();
    }
    out
}
81
82impl Indicator for BollingerBands {
85 fn name(&self) -> &'static str {
86 "BollingerBands"
87 }
88 fn required_len(&self) -> usize {
89 self.params.period
90 }
91 fn required_columns(&self) -> &[&'static str] {
92 &["close"]
93 }
94
95 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
96 self.check_len(candles)?;
97
98 let prices = self.params.column.extract(candles);
99 let num_bars = prices.len();
100 let period = self.params.period;
101 let std_mult = self.params.std_dev;
102
103 let mut middle = vec![f64::NAN; num_bars];
105 for i in (period - 1)..num_bars {
106 middle[i] = prices[(i + 1 - period)..=i].iter().sum::<f64>() / period as f64;
107 }
108
109 let std = rolling_std(&prices, period);
110
111 let mut upper = vec![f64::NAN; num_bars];
112 let mut lower = vec![f64::NAN; num_bars];
113 let mut bandwidth = vec![f64::NAN; num_bars];
114 let mut pct_b = vec![f64::NAN; num_bars];
115
116 for i in (period - 1)..num_bars {
117 let upper_val = middle[i] + std_mult * std[i];
118 let lower_val = middle[i] - std_mult * std[i];
119 upper[i] = upper_val;
120 lower[i] = lower_val;
121 bandwidth[i] = if middle[i] == 0.0 {
122 f64::NAN
123 } else {
124 (upper_val - lower_val) / middle[i]
125 };
126 let band_range = upper_val - lower_val;
127 pct_b[i] = if band_range == 0.0 {
128 f64::NAN
129 } else {
130 (prices[i] - lower_val) / band_range
131 };
132 }
133
134 Ok(IndicatorOutput::from_pairs([
135 ("BB_middle".to_string(), middle),
136 ("BB_upper".to_string(), upper),
137 ("BB_lower".to_string(), lower),
138 ("BB_bandwidth".to_string(), bandwidth),
139 ("BB_pct_b".to_string(), pct_b),
140 ]))
141 }
142}
143
144pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
147 let period = param_usize(params, "period", 20)?;
148 let std_dev = param_f64(params, "std_dev", 2.0)?;
149 let column = match param_str(params, "column", "close") {
150 "open" => PriceColumn::Open,
151 "high" => PriceColumn::High,
152 "low" => PriceColumn::Low,
153 _ => PriceColumn::Close,
154 };
155 Ok(Box::new(BollingerBands::new(BollingerParams {
156 period,
157 std_dev,
158 column,
159 })))
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles from closing prices: open equals close, high/low sit one
    /// unit either side, volume is constant and `time` is the bar index.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        let to_candle = |(idx, &price): (usize, &f64)| Candle {
            time: i64::try_from(idx).expect("time index fits i64"),
            open: price,
            high: price + 1.0,
            low: price - 1.0,
            close: price,
            volume: 100.0,
        };
        closes.iter().enumerate().map(to_candle).collect()
    }

    #[test]
    fn bb_five_output_columns() {
        let bars = candles(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let out = BollingerBands::with_period(5).calculate(&bars).unwrap();
        for column in ["BB_middle", "BB_upper", "BB_lower", "BB_bandwidth", "BB_pct_b"] {
            assert!(out.get(column).is_some());
        }
    }

    #[test]
    fn bb_upper_always_above_lower() {
        let closes: Vec<f64> = (1..=20).map(f64::from).collect();
        let out = BollingerBands::with_period(5).calculate(&candles(&closes)).unwrap();
        let upper = out.get("BB_upper").unwrap();
        let lower = out.get("BB_lower").unwrap();
        for i in 4..20 {
            assert!(upper[i] >= lower[i], "upper < lower at {i}");
        }
    }

    #[test]
    fn bb_correct_warm_up() {
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0];
        let out = BollingerBands::with_period(5).calculate(&candles(&closes)).unwrap();
        let mid = out.get("BB_middle").unwrap();
        for (i, v) in mid.iter().take(4).enumerate() {
            assert!(v.is_nan(), "expected NaN at {i}");
        }
        assert!(!mid[4].is_nan());
    }

    #[test]
    fn bb_constant_prices_bandwidth_zero() {
        let closes = [10.0_f64; 10];
        let out = BollingerBands::with_period(5).calculate(&candles(&closes)).unwrap();
        let bw = out.get("BB_bandwidth").unwrap();
        let last = bw[9];
        assert!(last.abs() < 1e-9 || last.is_nan());
    }

    #[test]
    fn bb_middle_equals_sma() {
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0];
        let out = BollingerBands::with_period(5).calculate(&candles(&closes)).unwrap();
        let mid = out.get("BB_middle").unwrap();
        assert!((mid[4] - 3.0).abs() < 1e-9, "SMA mismatch: {}", mid[4]);
    }

    #[test]
    fn factory_creates_bollinger() {
        let params: HashMap<String, String> = HashMap::new();
        let indicator = factory(&params).unwrap();
        assert_eq!(indicator.name(), "BollingerBands");
    }
}