finance_query/indicators/
bollinger.rs1use super::{IndicatorError, Result};
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub struct BollingerBands {
9 pub upper: Vec<Option<f64>>,
11
12 pub middle: Vec<Option<f64>>,
14
15 pub lower: Vec<Option<f64>>,
17}
18
19pub fn bollinger_bands(
50 data: &[f64],
51 period: usize,
52 std_dev_multiplier: f64,
53) -> Result<BollingerBands> {
54 if period == 0 {
55 return Err(IndicatorError::InvalidPeriod(
56 "Period must be greater than 0".to_string(),
57 ));
58 }
59
60 if data.len() < period {
61 return Err(IndicatorError::InsufficientData {
62 need: period,
63 got: data.len(),
64 });
65 }
66
67 let period_f = period as f64;
71 let mut sum_x: f64 = data[..period].iter().sum();
72 let mut sum_x2: f64 = data[..period].iter().map(|&x| x * x).sum();
73
74 let mut upper = vec![None; data.len()];
75 let mut middle = vec![None; data.len()];
76 let mut lower = vec![None; data.len()];
77
78 let emit = |sum_x: f64, sum_x2: f64| -> (f64, f64) {
79 let mean = sum_x / period_f;
80 let variance = (sum_x2 / period_f - mean * mean).max(0.0);
81 (mean, variance.sqrt())
82 };
83
84 let (mean, std_dev) = emit(sum_x, sum_x2);
85 middle[period - 1] = Some(mean);
86 upper[period - 1] = Some(mean + std_dev_multiplier * std_dev);
87 lower[period - 1] = Some(mean - std_dev_multiplier * std_dev);
88
89 for i in period..data.len() {
90 sum_x += data[i] - data[i - period];
91 sum_x2 += data[i] * data[i] - data[i - period] * data[i - period];
92 let (mean, std_dev) = emit(sum_x, sum_x2);
93 middle[i] = Some(mean);
94 upper[i] = Some(mean + std_dev_multiplier * std_dev);
95 lower[i] = Some(mean - std_dev_multiplier * std_dev);
96 }
97
98 Ok(BollingerBands {
99 upper,
100 middle,
101 lower,
102 })
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bollinger_bands_basic() {
        let data: Vec<f64> = (1..=30).map(|x| x as f64).collect();
        let result = bollinger_bands(&data, 20, 2.0).unwrap();

        assert_eq!(result.upper.len(), 30);
        assert_eq!(result.middle.len(), 30);
        assert_eq!(result.lower.len(), 30);

        // Warm-up: no complete 20-sample window exists before index 19.
        for i in 0..19 {
            assert_eq!(result.upper[i], None);
            assert_eq!(result.middle[i], None);
            assert_eq!(result.lower[i], None);
        }

        assert!(result.upper[19].is_some());
        assert!(result.middle[19].is_some());
        assert!(result.lower[19].is_some());

        // Every window holds 20 consecutive integers, so the population
        // variance is always (20^2 - 1) / 12 = 33.25.
        let expected_band = 2.0 * 33.25_f64.sqrt();

        for i in 19..30 {
            let upper = result.upper[i].unwrap();
            let middle = result.middle[i].unwrap();
            let lower = result.lower[i].unwrap();

            // The window ending at index i holds the values (i - 18)..=(i + 1),
            // so its mean is i - 8.5.
            let expected_mean = i as f64 - 8.5;
            assert!(
                (middle - expected_mean).abs() < 1e-9,
                "Middle ({}) should be {} at index {}",
                middle,
                expected_mean,
                i
            );
            assert!((upper - (expected_mean + expected_band)).abs() < 1e-9);
            assert!((lower - (expected_mean - expected_band)).abs() < 1e-9);

            assert!(
                upper > middle,
                "Upper ({}) should be > middle ({}) at index {}",
                upper,
                middle,
                i
            );
            assert!(
                middle > lower,
                "Middle ({}) should be > lower ({}) at index {}",
                middle,
                lower,
                i
            );
        }
    }

    #[test]
    fn test_bollinger_bands_constant_price() {
        // Zero variance: all three bands collapse onto the price.
        let data = vec![50.0; 30];
        let result = bollinger_bands(&data, 20, 2.0).unwrap();

        for i in 19..30 {
            let upper = result.upper[i].unwrap();
            let middle = result.middle[i].unwrap();
            let lower = result.lower[i].unwrap();

            assert!((upper - middle).abs() < 0.0001);
            assert!((middle - lower).abs() < 0.0001);
            assert!((middle - 50.0).abs() < 0.0001);
        }
    }

    #[test]
    fn test_bollinger_bands_exact_length() {
        // data.len() == period: exactly one window, at the last index.
        let data = vec![1.0, 2.0, 3.0];
        let result = bollinger_bands(&data, 3, 2.0).unwrap();

        assert_eq!(result.middle[0], None);
        assert_eq!(result.middle[1], None);
        assert!((result.middle[2].unwrap() - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_bollinger_bands_zero_period() {
        let data = vec![1.0, 2.0, 3.0];
        let result = bollinger_bands(&data, 0, 2.0);

        assert!(matches!(result, Err(IndicatorError::InvalidPeriod(_))));
    }

    #[test]
    fn test_bollinger_bands_insufficient_data() {
        let data = vec![1.0, 2.0, 3.0];
        let result = bollinger_bands(&data, 20, 2.0);

        assert!(matches!(
            result,
            Err(IndicatorError::InsufficientData { need: 20, got: 3 })
        ));
    }

    #[test]
    fn test_bollinger_bands_volatility() {
        // Alternates between 50 and 51.
        let low_vol_data: Vec<f64> = (1..=30).map(|x| 50.0 + (x % 2) as f64).collect();
        // Sweeps 50..=95 in steps of 5.
        let high_vol_data: Vec<f64> = (1..=30).map(|x| 50.0 + (x % 10) as f64 * 5.0).collect();

        let low_vol_result = bollinger_bands(&low_vol_data, 20, 2.0).unwrap();
        let high_vol_result = bollinger_bands(&high_vol_data, 20, 2.0).unwrap();

        let low_vol_width = low_vol_result.upper[29].unwrap() - low_vol_result.lower[29].unwrap();
        let high_vol_width =
            high_vol_result.upper[29].unwrap() - high_vol_result.lower[29].unwrap();

        assert!(
            high_vol_width > low_vol_width,
            "High volatility bands ({}) should be wider than low volatility bands ({})",
            high_vol_width,
            low_vol_width
        );
    }
}