// indicators/volatility/bollinger.rs

//! Bollinger Bands.
//!
//! Ported from `bollinger.py` :: `class BollingerBands`.
//!
//! # Algorithm
//!
//! 1. `middle[i] = SMA(prices, period)`
//! 2. `std[i]    = rolling sample std-dev (ddof=1, matches pandas)`
//! 3. `upper[i]  = middle[i] + std_dev × std[i]`
//! 4. `lower[i]  = middle[i] − std_dev × std[i]`
//! 5. `bandwidth = (upper − lower) / middle`
//! 6. `percent_b = (price − lower) / (upper − lower)`
//!
//! The first `period − 1` values of every column are `NaN` (warm-up).
//! `bandwidth` is `NaN` where `middle` is exactly 0. `percent_b` is `NaN`
//! where `upper == lower`, i.e. a zero-deviation window.
//!
//! Output columns: `"BB_middle"`, `"BB_upper"`, `"BB_lower"`,
//! `"BB_bandwidth"`, `"BB_pct_b"`.

17use std::collections::HashMap;
18
19use crate::error::IndicatorError;
20use crate::indicator::{Indicator, IndicatorOutput, PriceColumn};
21use crate::registry::{param_f64, param_str, param_usize};
22use crate::types::Candle;
23
24// ── Params ────────────────────────────────────────────────────────────────────
25
26#[derive(Debug, Clone)]
27pub struct BollingerParams {
28    /// Rolling window size.  Python default: 20.
29    pub period: usize,
30    /// Number of standard deviations.  Python default: 2.0.
31    pub std_dev: f64,
32    /// Price field.  Python default: `"close"`.
33    pub column: PriceColumn,
34}
35
36impl Default for BollingerParams {
37    fn default() -> Self {
38        Self {
39            period: 20,
40            std_dev: 2.0,
41            column: PriceColumn::Close,
42        }
43    }
44}
45
46// ── Indicator struct ──────────────────────────────────────────────────────────
47
48#[derive(Debug, Clone)]
49pub struct BollingerBands {
50    pub params: BollingerParams,
51}
52
53impl BollingerBands {
54    pub fn new(params: BollingerParams) -> Self {
55        Self { params }
56    }
57
58    pub fn with_period(period: usize) -> Self {
59        Self::new(BollingerParams {
60            period,
61            ..Default::default()
62        })
63    }
64}
65
// ── Helpers ───────────────────────────────────────────────────────────────────

/// Rolling sample standard deviation (ddof=1), matching `pandas rolling().std()`.
///
/// `out[i]` is the standard deviation of the window `prices[i + 1 - period..=i]`.
/// The result has the same length as `prices`, and its first `period - 1`
/// entries are `NaN` (warm-up). Edge cases:
///
/// * `period == 1`: the ddof=1 divisor `period - 1` is zero, so every entry is
///   `NaN` (0 / 0).
/// * `period == 0`: there is no window at all, so the result is all `NaN`. This
///   avoids underflowing `period - 1`.
/// * `period > prices.len()`: no full window exists, so the result is all `NaN`.
fn rolling_std(prices: &[f64], period: usize) -> Vec<f64> {
    let mut out = vec![f64::NAN; prices.len()];
    if period == 0 {
        return out;
    }

    let count = period as f64;
    let ddof_divisor = (period - 1) as f64;
    // Slot `period - 1 + k` receives the std-dev of the k-th full window.
    for (slot, window) in out.iter_mut().skip(period - 1).zip(prices.windows(period)) {
        let mean = window.iter().sum::<f64>() / count;
        let sum_sq = window.iter().map(|&x| (x - mean).powi(2)).sum::<f64>();
        *slot = (sum_sq / ddof_divisor).sqrt();
    }
    out
}

82// ── Indicator impl ────────────────────────────────────────────────────────────
83
84impl Indicator for BollingerBands {
85    fn name(&self) -> &'static str {
86        "BollingerBands"
87    }
88    fn required_len(&self) -> usize {
89        self.params.period
90    }
91    fn required_columns(&self) -> &[&'static str] {
92        &["close"]
93    }
94
95    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
96        self.check_len(candles)?;
97
98        let prices = self.params.column.extract(candles);
99        let num_bars = prices.len();
100        let period = self.params.period;
101        let std_mult = self.params.std_dev;
102
103        // Rolling SMA
104        let mut middle = vec![f64::NAN; num_bars];
105        for i in (period - 1)..num_bars {
106            middle[i] = prices[(i + 1 - period)..=i].iter().sum::<f64>() / period as f64;
107        }
108
109        let std = rolling_std(&prices, period);
110
111        let mut upper = vec![f64::NAN; num_bars];
112        let mut lower = vec![f64::NAN; num_bars];
113        let mut bandwidth = vec![f64::NAN; num_bars];
114        let mut pct_b = vec![f64::NAN; num_bars];
115
116        for i in (period - 1)..num_bars {
117            let upper_val = middle[i] + std_mult * std[i];
118            let lower_val = middle[i] - std_mult * std[i];
119            upper[i] = upper_val;
120            lower[i] = lower_val;
121            bandwidth[i] = if middle[i] == 0.0 {
122                f64::NAN
123            } else {
124                (upper_val - lower_val) / middle[i]
125            };
126            let band_range = upper_val - lower_val;
127            pct_b[i] = if band_range == 0.0 {
128                f64::NAN
129            } else {
130                (prices[i] - lower_val) / band_range
131            };
132        }
133
134        Ok(IndicatorOutput::from_pairs([
135            ("BB_middle".to_string(), middle),
136            ("BB_upper".to_string(), upper),
137            ("BB_lower".to_string(), lower),
138            ("BB_bandwidth".to_string(), bandwidth),
139            ("BB_pct_b".to_string(), pct_b),
140        ]))
141    }
142}
143
144// ── Registry factory ──────────────────────────────────────────────────────────
145
146pub fn factory<S: ::std::hash::BuildHasher>(
147    params: &HashMap<String, String, S>,
148) -> Result<Box<dyn Indicator>, IndicatorError> {
149    let period = param_usize(params, "period", 20)?;
150    let std_dev = param_f64(params, "std_dev", 2.0)?;
151    let column = match param_str(params, "column", "close") {
152        "open" => PriceColumn::Open,
153        "high" => PriceColumn::High,
154        "low" => PriceColumn::Low,
155        _ => PriceColumn::Close,
156    };
157    Ok(Box::new(BollingerBands::new(BollingerParams {
158        period,
159        std_dev,
160        column,
161    })))
162}
163
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    const COLUMNS: [&str; 5] = [
        "BB_middle",
        "BB_upper",
        "BB_lower",
        "BB_bandwidth",
        "BB_pct_b",
    ];

    /// Synthetic candles: open == close == `closes[i]`, high/low one unit either side.
    fn candles(closes: &[f64]) -> Vec<Candle> {
        closes
            .iter()
            .enumerate()
            .map(|(i, &c)| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: c + 1.0,
                low: c - 1.0,
                close: c,
                volume: 100.0,
            })
            .collect()
    }

    /// Asserts `actual` is within a tight absolute tolerance of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn bb_five_output_columns() {
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        for name in COLUMNS {
            let column = out
                .get(name)
                .unwrap_or_else(|| panic!("missing column {name}"));
            assert_eq!(column.len(), closes.len(), "wrong length for {name}");
        }
    }

    #[test]
    fn bb_upper_always_above_lower() {
        let closes: Vec<f64> = (1..=20).map(f64::from).collect();
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let upper = out.get("BB_upper").unwrap();
        let lower = out.get("BB_lower").unwrap();
        for (i, (u, l)) in upper.iter().zip(lower.iter()).enumerate().skip(4) {
            assert!(u >= l, "upper < lower at {i}");
        }
    }

    #[test]
    fn bb_correct_warm_up() {
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let mid = out.get("BB_middle").unwrap();
        for (i, &v) in mid.iter().enumerate().take(4) {
            assert!(v.is_nan(), "expected NaN at {i}");
        }
        assert!(!mid[4].is_nan());
    }

    #[test]
    fn bb_constant_prices_bandwidth_zero() {
        let closes = [10.0_f64; 10];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let bw = out.get("BB_bandwidth").unwrap();
        let pct_b = out.get("BB_pct_b").unwrap();
        // std = 0 → upper == lower == middle: zero bandwidth, undefined %B.
        for i in 4..closes.len() {
            assert_close(bw[i], 0.0);
            assert!(pct_b[i].is_nan(), "expected NaN %B at {i}");
        }
    }

    #[test]
    fn bb_middle_equals_sma() {
        // SMA(5) of [1..5] = 3.0
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let mid = out.get("BB_middle").unwrap();
        assert!((mid[4] - 3.0).abs() < 1e-9, "SMA mismatch: {}", mid[4]);
    }

    #[test]
    fn bb_reference_values() {
        // Window [1, 2, 3, 4, 5]: mean 3, sample std s = sqrt(2.5).
        let closes = [1.0, 2.0, 3.0, 4.0, 5.0];
        let out = BollingerBands::with_period(5)
            .calculate(&candles(&closes))
            .unwrap();
        let s = 2.5_f64.sqrt();
        assert_close(out.get("BB_middle").unwrap()[4], 3.0);
        assert_close(out.get("BB_upper").unwrap()[4], 3.0 + 2.0 * s);
        assert_close(out.get("BB_lower").unwrap()[4], 3.0 - 2.0 * s);
        assert_close(out.get("BB_bandwidth").unwrap()[4], 4.0 * s / 3.0);
        // (5 − (3 − 2s)) / 4s = 0.5 + 0.5 / s
        assert_close(out.get("BB_pct_b").unwrap()[4], 0.5 + 0.5 / s);
    }

    #[test]
    fn bb_period_one_only_middle_defined() {
        // ddof=1 with a single-bar window leaves the deviation undefined.
        let closes = [1.0, 2.0, 3.0];
        let out = BollingerBands::with_period(1)
            .calculate(&candles(&closes))
            .unwrap();
        let mid = out.get("BB_middle").unwrap();
        for (i, &c) in closes.iter().enumerate() {
            assert_close(mid[i], c);
        }
        for name in ["BB_upper", "BB_lower", "BB_bandwidth", "BB_pct_b"] {
            assert!(
                out.get(name).unwrap().iter().all(|v| v.is_nan()),
                "{name} should be NaN"
            );
        }
    }

    #[test]
    fn factory_creates_bollinger() {
        assert_eq!(factory(&HashMap::new()).unwrap().name(), "BollingerBands");
    }

    #[test]
    fn factory_reads_period_and_column() {
        let params = HashMap::from([
            ("period".to_string(), "5".to_string()),
            ("column".to_string(), "high".to_string()),
        ]);
        let indicator = factory(&params).unwrap();
        assert_eq!(indicator.required_len(), 5);
        // Highs are closes + 1, so SMA(5) of the highs for closes 1..=5 is 4.
        let out = indicator
            .calculate(&candles(&[1.0, 2.0, 3.0, 4.0, 5.0]))
            .unwrap();
        assert_close(out.get("BB_middle").unwrap()[4], 4.0);
    }
}