finance_query/indicators/
bollinger.rs1use super::{IndicatorError, Result, sma::sma};
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub struct BollingerBands {
9 pub upper: Vec<Option<f64>>,
11
12 pub middle: Vec<Option<f64>>,
14
15 pub lower: Vec<Option<f64>>,
17}
18
19pub fn bollinger_bands(
50 data: &[f64],
51 period: usize,
52 std_dev_multiplier: f64,
53) -> Result<BollingerBands> {
54 if period == 0 {
55 return Err(IndicatorError::InvalidPeriod(
56 "Period must be greater than 0".to_string(),
57 ));
58 }
59
60 if data.len() < period {
61 return Err(IndicatorError::InsufficientData {
62 need: period,
63 got: data.len(),
64 });
65 }
66
67 let middle = sma(data, period);
69
70 let mut upper = Vec::with_capacity(data.len());
71 let mut lower = Vec::with_capacity(data.len());
72
73 for i in 0..data.len() {
75 if i + 1 < period {
76 upper.push(None);
77 lower.push(None);
78 } else {
79 let window = &data[i + 1 - period..=i];
81 let mean = middle[i].unwrap(); let variance: f64 =
84 window.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / period as f64;
85
86 let std_dev = variance.sqrt();
87
88 upper.push(Some(mean + std_dev_multiplier * std_dev));
89 lower.push(Some(mean - std_dev_multiplier * std_dev));
90 }
91 }
92
93 Ok(BollingerBands {
94 upper,
95 middle,
96 lower,
97 })
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bollinger_bands_basic() {
        let data: Vec<f64> = (1..=30).map(|x| x as f64).collect();
        let result = bollinger_bands(&data, 20, 2.0).unwrap();

        assert_eq!(result.upper.len(), 30);
        assert_eq!(result.middle.len(), 30);
        assert_eq!(result.lower.len(), 30);

        // The first period - 1 = 19 positions have no complete window.
        for i in 0..19 {
            assert_eq!(result.upper[i], None);
            assert_eq!(result.middle[i], None);
            assert_eq!(result.lower[i], None);
        }

        assert!(result.upper[19].is_some());
        assert!(result.middle[19].is_some());
        assert!(result.lower[19].is_some());

        // Strictly increasing data has non-zero spread, so the bands must separate.
        for i in 19..30 {
            let upper = result.upper[i].unwrap();
            let middle = result.middle[i].unwrap();
            let lower = result.lower[i].unwrap();

            assert!(
                upper > middle,
                "Upper ({}) should be > middle ({}) at index {}",
                upper,
                middle,
                i
            );
            assert!(
                middle > lower,
                "Middle ({}) should be > lower ({}) at index {}",
                middle,
                lower,
                i
            );
        }
    }

    #[test]
    fn test_bollinger_bands_known_values() {
        // Window [1, 2, 3, 4, 5]: mean 3, population variance
        // (4 + 1 + 0 + 1 + 4) / 5 = 2, so std dev = sqrt(2).
        // This also covers the boundary data.len() == period.
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        let result = bollinger_bands(&data, 5, 2.0).unwrap();
        let spread = 2.0 * 2.0_f64.sqrt();

        for i in 0..4 {
            assert_eq!(result.upper[i], None);
            assert_eq!(result.lower[i], None);
        }
        assert!((result.middle[4].unwrap() - 3.0).abs() < 1e-9);
        assert!((result.upper[4].unwrap() - (3.0 + spread)).abs() < 1e-9);
        assert!((result.lower[4].unwrap() - (3.0 - spread)).abs() < 1e-9);
    }

    #[test]
    fn test_bollinger_bands_constant_price() {
        // Zero variance: all three bands collapse onto the price.
        let data = [50.0; 30];
        let result = bollinger_bands(&data, 20, 2.0).unwrap();

        for i in 19..30 {
            let upper = result.upper[i].unwrap();
            let middle = result.middle[i].unwrap();
            let lower = result.lower[i].unwrap();

            assert!((upper - middle).abs() < 0.0001);
            assert!((middle - lower).abs() < 0.0001);
            assert!((middle - 50.0).abs() < 0.0001);
        }
    }

    #[test]
    fn test_bollinger_bands_insufficient_data() {
        let data = [1.0, 2.0, 3.0];
        let result = bollinger_bands(&data, 20, 2.0);

        assert!(
            matches!(
                result,
                Err(IndicatorError::InsufficientData { need: 20, got: 3 })
            ),
            "expected InsufficientData {{ need: 20, got: 3 }}"
        );
    }

    #[test]
    fn test_bollinger_bands_zero_period() {
        let data = [1.0, 2.0, 3.0];
        let result = bollinger_bands(&data, 0, 2.0);

        assert!(
            matches!(result, Err(IndicatorError::InvalidPeriod(_))),
            "expected InvalidPeriod for period == 0"
        );
    }

    #[test]
    fn test_bollinger_bands_volatility() {
        let low_vol_data: Vec<f64> = (1..=30).map(|x| 50.0 + (x % 2) as f64).collect();
        let high_vol_data: Vec<f64> = (1..=30).map(|x| 50.0 + (x % 10) as f64 * 5.0).collect();

        let low_vol_result = bollinger_bands(&low_vol_data, 20, 2.0).unwrap();
        let high_vol_result = bollinger_bands(&high_vol_data, 20, 2.0).unwrap();

        let low_vol_width = low_vol_result.upper[29].unwrap() - low_vol_result.lower[29].unwrap();
        let high_vol_width =
            high_vol_result.upper[29].unwrap() - high_vol_result.lower[29].unwrap();

        assert!(
            high_vol_width > low_vol_width,
            "High volatility bands ({}) should be wider than low volatility bands ({})",
            high_vol_width,
            low_vol_width
        );
    }
}