// indicators/volatility/bollinger.rs

//! Bollinger Bands.
//!
//! Ported from `bollinger.py` :: `class BollingerBands`.
//!
//! # Algorithm
//!
//! 1. `middle[i] = SMA(prices, period)`
//! 2. `std[i]    = rolling sample std-dev (ddof=1, matches pandas)`
//! 3. `upper[i]  = middle[i] + std_dev × std[i]`
//! 4. `lower[i]  = middle[i] − std_dev × std[i]`
//! 5. `bandwidth = (upper − lower) / middle`
//! 6. `percent_b = (price − lower) / (upper − lower)`
//!
//! The first `period − 1` entries of every column are `NaN` (warm-up).
//! `bandwidth` is `NaN` where `middle == 0`, and `percent_b` is `NaN` where
//! `upper == lower` (e.g. a window of constant prices, whose std-dev is zero).
//!
//! Output columns: `"BB_middle"`, `"BB_upper"`, `"BB_lower"`,
//! `"BB_bandwidth"`, `"BB_pct_b"`.

17use std::collections::HashMap;
18
19use crate::error::IndicatorError;
20use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
21use crate::registry::{param_f64, param_str, param_usize};
22use crate::types::Candle;
23
24// ── Params ────────────────────────────────────────────────────────────────────
25
26#[derive(Debug, Clone)]
27pub struct BollingerParams {
28    /// Rolling window size.  Python default: 20.
29    pub period: usize,
30    /// Number of standard deviations.  Python default: 2.0.
31    pub std_dev: f64,
32    /// Price field.  Python default: `"close"`.
33    pub column: PriceColumn,
34}
35
36impl Default for BollingerParams {
37    fn default() -> Self {
38        Self {
39            period: 20,
40            std_dev: 2.0,
41            column: PriceColumn::Close,
42        }
43    }
44}
45
46// ── Indicator struct ──────────────────────────────────────────────────────────
47
48#[derive(Debug, Clone)]
49pub struct BollingerBands {
50    pub params: BollingerParams,
51}
52
53impl BollingerBands {
54    pub fn new(params: BollingerParams) -> Self {
55        Self { params }
56    }
57
58    pub fn with_period(period: usize) -> Self {
59        Self::new(BollingerParams {
60            period,
61            ..Default::default()
62        })
63    }
64}
65
// ── Helpers ───────────────────────────────────────────────────────────────────

/// Rolling sample standard deviation (ddof=1), matching `pandas rolling().std()`.
///
/// Returns a vector the same length as `prices`. Entry `i` holds the standard
/// deviation of `prices[i + 1 - period..=i]`; the first `period - 1` entries
/// (the warm-up) are `NaN`.
///
/// Edge cases:
/// * `period == 0` — no window is defined, so every entry is `NaN` (checked
///   explicitly so that `period - 1` cannot underflow).
/// * `period == 1` — the ddof=1 denominator is zero, so every entry is `NaN`.
/// * `period > prices.len()` — no full window exists, so every entry is `NaN`.
fn rolling_std(prices: &[f64], period: usize) -> Vec<f64> {
    let mut out = vec![f64::NAN; prices.len()];
    if period == 0 {
        return out;
    }

    let window_len = period as f64;
    let ddof_denominator = (period - 1) as f64; // ddof=1
    // The k-th full window from `windows(period)` ends at index `k + period - 1`,
    // so skipping the warm-up slots lines each window up with its output slot.
    for (slot, window) in out.iter_mut().skip(period - 1).zip(prices.windows(period)) {
        let mean = window.iter().sum::<f64>() / window_len;
        let sum_sq = window.iter().map(|&x| (x - mean).powi(2)).sum::<f64>();
        *slot = (sum_sq / ddof_denominator).sqrt();
    }
    out
}

82// ── Indicator impl ────────────────────────────────────────────────────────────
83
84impl Indicator for BollingerBands {
85    fn name(&self) -> &'static str {
86        "BollingerBands"
87    }
88    fn required_len(&self) -> usize {
89        self.params.period
90    }
91    fn required_columns(&self) -> &[&'static str] {
92        &["close"]
93    }
94
95    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
96        self.check_len(candles)?;
97
98        let prices = self.params.column.extract(candles);
99        let num_bars = prices.len();
100        let period = self.params.period;
101        let std_mult = self.params.std_dev;
102
103        // Rolling SMA
104        let mut middle = vec![f64::NAN; num_bars];
105        for i in (period - 1)..num_bars {
106            middle[i] = prices[(i + 1 - period)..=i].iter().sum::<f64>() / period as f64;
107        }
108
109        let std = rolling_std(&prices, period);
110
111        let mut upper = vec![f64::NAN; num_bars];
112        let mut lower = vec![f64::NAN; num_bars];
113        let mut bandwidth = vec![f64::NAN; num_bars];
114        let mut pct_b = vec![f64::NAN; num_bars];
115
116        for i in (period - 1)..num_bars {
117            let upper_val = middle[i] + std_mult * std[i];
118            let lower_val = middle[i] - std_mult * std[i];
119            upper[i] = upper_val;
120            lower[i] = lower_val;
121            bandwidth[i] = if middle[i] == 0.0 {
122                f64::NAN
123            } else {
124                (upper_val - lower_val) / middle[i]
125            };
126            let band_range = upper_val - lower_val;
127            pct_b[i] = if band_range == 0.0 {
128                f64::NAN
129            } else {
130                (prices[i] - lower_val) / band_range
131            };
132        }
133
134        Ok(IndicatorOutput::from_pairs([
135            ("BB_middle".to_string(), middle),
136            ("BB_upper".to_string(), upper),
137            ("BB_lower".to_string(), lower),
138            ("BB_bandwidth".to_string(), bandwidth),
139            ("BB_pct_b".to_string(), pct_b),
140        ]))
141    }
142}
143
144// ── Registry factory ──────────────────────────────────────────────────────────
145
146pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
147    let period = param_usize(params, "period", 20)?;
148    let std_dev = param_f64(params, "std_dev", 2.0)?;
149    let column = match param_str(params, "column", "close") {
150        "open" => PriceColumn::Open,
151        "high" => PriceColumn::High,
152        "low" => PriceColumn::Low,
153        _ => PriceColumn::Close,
154    };
155    Ok(Box::new(BollingerBands::new(BollingerParams {
156        period,
157        std_dev,
158        column,
159    })))
160}
161
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    const COLUMNS: [&str; 5] = ["BB_middle", "BB_upper", "BB_lower", "BB_bandwidth", "BB_pct_b"];

    /// Builds one candle per close; `high`/`low` straddle the close by 1.0.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c + 1.0,
                low: c - 1.0,
                close: c,
                volume: 100.0,
            })
            .collect()
    }

    #[test]
    fn bb_five_output_columns() {
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        for name in COLUMNS {
            assert!(out.get(name).is_some(), "missing column {name}");
        }
    }

    #[test]
    fn bb_upper_always_above_lower() {
        let closes: Vec<f64> = (1..=20).map(|x| x as f64).collect();
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let upper = out.get("BB_upper").unwrap();
        let lower = out.get("BB_lower").unwrap();
        for i in 4..20 {
            assert!(upper[i] >= lower[i], "upper < lower at {i}");
        }
    }

    #[test]
    fn bb_correct_warm_up() {
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        // Every column, not just the middle band, must be NaN during warm-up.
        for name in COLUMNS {
            let col = out.get(name).unwrap();
            for (i, &v) in col.iter().enumerate().take(4) {
                assert!(v.is_nan(), "{name}: expected NaN at {i}");
            }
            assert!(!col[4].is_nan(), "{name}: expected a value at 4");
        }
    }

    #[test]
    fn bb_constant_prices_bandwidth_zero() {
        let closes = [10.0f64; 10];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let bw = out.get("BB_bandwidth").unwrap();
        let pct_b = out.get("BB_pct_b").unwrap();
        // std = 0 → upper == lower == middle = 10, so bandwidth is exactly 0
        // (middle is non-zero) while %b is undefined because the band range is 0.
        assert!(bw[9].abs() < 1e-12, "bandwidth should be 0, got {}", bw[9]);
        assert!(pct_b[9].is_nan(), "pct_b should be NaN, got {}", pct_b[9]);
    }

    #[test]
    fn bb_middle_equals_sma() {
        // SMA(5) of [1..5] = 3.0
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let mid = out.get("BB_middle").unwrap();
        assert!((mid[4] - 3.0).abs() < 1e-9, "SMA mismatch: {}", mid[4]);
    }

    #[test]
    fn bb_band_values_use_sample_std() {
        // Window [1..5]: mean 3, sample variance (ddof=1) = 10 / 4 = 2.5.
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let sd = 2.5f64.sqrt();
        let upper = out.get("BB_upper").unwrap()[4];
        let lower = out.get("BB_lower").unwrap()[4];
        let bw = out.get("BB_bandwidth").unwrap()[4];
        let pct_b = out.get("BB_pct_b").unwrap()[4];
        assert!((upper - (3.0 + 2.0 * sd)).abs() < 1e-9, "upper: {upper}");
        assert!((lower - (3.0 - 2.0 * sd)).abs() < 1e-9, "lower: {lower}");
        assert!((bw - 4.0 * sd / 3.0).abs() < 1e-9, "bandwidth: {bw}");
        assert!((pct_b - (0.5 + 0.5 / sd)).abs() < 1e-9, "pct_b: {pct_b}");
    }

    #[test]
    fn factory_creates_bollinger() {
        assert_eq!(factory(&HashMap::new()).unwrap().name(), "BollingerBands");
    }

    #[test]
    fn factory_forwards_period() {
        let params: HashMap<String, String> =
            HashMap::from([("period".to_string(), "3".to_string())]);
        assert_eq!(factory(&params).unwrap().required_len(), 3);
    }
}
244    #[test]
245    fn factory_creates_bollinger() {
246        assert_eq!(factory(&HashMap::new()).unwrap().name(), "BollingerBands");
247    }
248}