Skip to main content

finance_query/indicators/
bollinger.rs

1//! Bollinger Bands indicator.
2
3use super::{IndicatorError, Result, sma::sma};
4use serde::{Deserialize, Serialize};
5
6/// Bollinger Bands result containing upper, middle, and lower bands.
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub struct BollingerBands {
9    /// Upper band (SMA + std_dev * multiplier)
10    pub upper: Vec<Option<f64>>,
11
12    /// Middle band (SMA)
13    pub middle: Vec<Option<f64>>,
14
15    /// Lower band (SMA - std_dev * multiplier)
16    pub lower: Vec<Option<f64>>,
17}
18
19/// Calculate Bollinger Bands.
20///
21/// Bollinger Bands consist of a middle band (SMA) and upper/lower bands that are
22/// standard deviations away from the middle band. They help identify volatility and
23/// potential overbought/oversold conditions.
24///
25/// # Arguments
26///
27/// * `data` - Price data (typically close prices)
28/// * `period` - Number of periods for the SMA (typically 20)
29/// * `std_dev_multiplier` - Number of standard deviations (typically 2.0)
30///
31/// # Formula
32///
33/// - Middle Band = SMA(period)
34/// - Upper Band = Middle Band + (std_dev_multiplier × standard deviation)
35/// - Lower Band = Middle Band - (std_dev_multiplier × standard deviation)
36///
37/// # Example
38///
39/// ```
40/// use finance_query::indicators::bollinger_bands;
41///
42/// let prices: Vec<f64> = (1..=30).map(|x| x as f64 + (x % 3) as f64).collect();
43/// let result = bollinger_bands(&prices, 20, 2.0).unwrap();
44///
45/// assert_eq!(result.upper.len(), prices.len());
46/// assert_eq!(result.middle.len(), prices.len());
47/// assert_eq!(result.lower.len(), prices.len());
48/// ```
49pub fn bollinger_bands(
50    data: &[f64],
51    period: usize,
52    std_dev_multiplier: f64,
53) -> Result<BollingerBands> {
54    if period == 0 {
55        return Err(IndicatorError::InvalidPeriod(
56            "Period must be greater than 0".to_string(),
57        ));
58    }
59
60    if data.len() < period {
61        return Err(IndicatorError::InsufficientData {
62            need: period,
63            got: data.len(),
64        });
65    }
66
67    // Calculate middle band (SMA)
68    let middle = sma(data, period);
69
70    let mut upper = Vec::with_capacity(data.len());
71    let mut lower = Vec::with_capacity(data.len());
72
73    // Calculate upper and lower bands
74    for i in 0..data.len() {
75        if i + 1 < period {
76            upper.push(None);
77            lower.push(None);
78        } else {
79            // Calculate standard deviation for this window
80            let window = &data[i + 1 - period..=i];
81            let mean = middle[i].unwrap(); // We know this exists because i >= period - 1
82
83            let variance: f64 =
84                window.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / period as f64;
85
86            let std_dev = variance.sqrt();
87
88            upper.push(Some(mean + std_dev_multiplier * std_dev));
89            lower.push(Some(mean - std_dev_multiplier * std_dev));
90        }
91    }
92
93    Ok(BollingerBands {
94        upper,
95        middle,
96        lower,
97    })
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floats agree to within a tight absolute tolerance.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_bollinger_bands_basic() {
        let data: Vec<f64> = (1..=30).map(|x| x as f64).collect();
        let result = bollinger_bands(&data, 20, 2.0).unwrap();

        assert_eq!(result.upper.len(), 30);
        assert_eq!(result.middle.len(), 30);
        assert_eq!(result.lower.len(), 30);

        // First 19 values should be None
        for i in 0..19 {
            assert_eq!(result.upper[i], None);
            assert_eq!(result.middle[i], None);
            assert_eq!(result.lower[i], None);
        }

        // Values after period should exist
        assert!(result.upper[19].is_some());
        assert!(result.middle[19].is_some());
        assert!(result.lower[19].is_some());

        // Upper should be > Middle > Lower
        for i in 19..30 {
            let upper = result.upper[i].unwrap();
            let middle = result.middle[i].unwrap();
            let lower = result.lower[i].unwrap();

            assert!(
                upper > middle,
                "Upper ({}) should be > middle ({}) at index {}",
                upper,
                middle,
                i
            );
            assert!(
                middle > lower,
                "Middle ({}) should be > lower ({}) at index {}",
                middle,
                lower,
                i
            );
        }
    }

    #[test]
    fn test_bollinger_bands_known_values() {
        // Classic example: mean 5, population standard deviation exactly 2.
        let data = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let result = bollinger_bands(&data, 8, 2.0).unwrap();

        for i in 0..7 {
            assert_eq!(result.upper[i], None);
            assert_eq!(result.middle[i], None);
            assert_eq!(result.lower[i], None);
        }

        assert_close(result.middle[7].unwrap(), 5.0);
        assert_close(result.upper[7].unwrap(), 9.0);
        assert_close(result.lower[7].unwrap(), 1.0);
    }

    #[test]
    fn test_bollinger_bands_symmetric_about_middle() {
        let data: Vec<f64> = (1..=30).map(|x| x as f64 + (x % 3) as f64).collect();
        let result = bollinger_bands(&data, 5, 2.0).unwrap();

        for i in 4..30 {
            let upper = result.upper[i].unwrap();
            let middle = result.middle[i].unwrap();
            let lower = result.lower[i].unwrap();
            assert_close(upper - middle, middle - lower);
        }
    }

    #[test]
    fn test_bollinger_bands_constant_price() {
        // Constant price should have zero standard deviation
        let data = vec![50.0; 30];
        let result = bollinger_bands(&data, 20, 2.0).unwrap();

        // All bands should be equal when std dev is 0
        for i in 19..30 {
            let upper = result.upper[i].unwrap();
            let middle = result.middle[i].unwrap();
            let lower = result.lower[i].unwrap();

            assert!((upper - middle).abs() < 0.0001);
            assert!((middle - lower).abs() < 0.0001);
            assert!((middle - 50.0).abs() < 0.0001);
        }
    }

    #[test]
    fn test_bollinger_bands_insufficient_data() {
        let data = vec![1.0, 2.0, 3.0];
        let result = bollinger_bands(&data, 20, 2.0);

        assert!(matches!(
            result,
            Err(IndicatorError::InsufficientData { need: 20, got: 3 })
        ));
    }

    #[test]
    fn test_bollinger_bands_zero_period() {
        let data = vec![1.0, 2.0, 3.0];
        let result = bollinger_bands(&data, 0, 2.0);

        assert!(matches!(result, Err(IndicatorError::InvalidPeriod(_))));
    }

    #[test]
    fn test_bollinger_bands_volatility() {
        // Higher volatility should create wider bands
        let low_vol_data: Vec<f64> = (1..=30).map(|x| 50.0 + (x % 2) as f64).collect();
        let high_vol_data: Vec<f64> = (1..=30).map(|x| 50.0 + (x % 10) as f64 * 5.0).collect();

        let low_vol_result = bollinger_bands(&low_vol_data, 20, 2.0).unwrap();
        let high_vol_result = bollinger_bands(&high_vol_data, 20, 2.0).unwrap();

        // Compare band width at the last data point
        let low_vol_width = low_vol_result.upper[29].unwrap() - low_vol_result.lower[29].unwrap();
        let high_vol_width =
            high_vol_result.upper[29].unwrap() - high_vol_result.lower[29].unwrap();

        assert!(
            high_vol_width > low_vol_width,
            "High volatility bands ({}) should be wider than low volatility bands ({})",
            high_vol_width,
            low_vol_width
        );
    }
}