use super::{IndicatorError, Result};
use serde::{Deserialize, Serialize};
/// The three output series produced by [`bollinger_bands`].
///
/// All three vectors are index-aligned with the input slice and have the same
/// length as it; positions before the first complete window are `None`.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct BollingerBands {
    /// Middle band plus `std_dev_multiplier` window standard deviations.
    pub upper: Vec<Option<f64>>,
    /// Arithmetic mean of each window (simple moving average).
    pub middle: Vec<Option<f64>>,
    /// Middle band minus `std_dev_multiplier` window standard deviations.
    pub lower: Vec<Option<f64>>,
}
/// Computes Bollinger Bands over `data` with a moving window of `period` samples.
///
/// For every index `i >= period - 1`, the window `data[i + 1 - period..=i]` yields:
/// - `middle[i]`: the arithmetic mean of the window,
/// - `upper[i]`: `middle[i] + std_dev_multiplier * sigma`,
/// - `lower[i]`: `middle[i] - std_dev_multiplier * sigma`,
///
/// where `sigma` is the *population* standard deviation of the window (divisor
/// `period`, not `period - 1`). The first `period - 1` entries of each output
/// vector are `None`, and every vector has `data.len()` entries. A negative
/// `std_dev_multiplier` places `upper` below `middle` and `lower` above it.
///
/// A non-finite value in `data` only affects the windows that contain it.
///
/// # Errors
/// - [`IndicatorError::InvalidPeriod`] if `period == 0`.
/// - [`IndicatorError::InsufficientData`] if `data.len() < period`.
pub fn bollinger_bands(
    data: &[f64],
    period: usize,
    std_dev_multiplier: f64,
) -> Result<BollingerBands> {
    if period == 0 {
        return Err(IndicatorError::InvalidPeriod(
            "Period must be greater than 0".to_string(),
        ));
    }
    if data.len() < period {
        return Err(IndicatorError::InsufficientData {
            need: period,
            got: data.len(),
        });
    }

    let period_f = period as f64;
    // Index of the last sample of the first complete window.
    let warmup = period - 1;
    let mut upper = vec![None; data.len()];
    let mut middle = vec![None; data.len()];
    let mut lower = vec![None; data.len()];

    // Every window is evaluated on its own with a two-pass mean/variance.
    // A rolling sum / sum-of-squares (E[x^2] - E[x]^2) is O(n), but it has two problems:
    // - It loses all precision once large values leave the window.
    // - A single NaN/inf permanently poisons every later window.
    // O(n * period) is negligible for typical indicator periods.
    for (offset, window) in data.windows(period).enumerate() {
        let mean = window.iter().sum::<f64>() / period_f;
        let variance = window
            .iter()
            .map(|&x| {
                let deviation = x - mean;
                deviation * deviation
            })
            .sum::<f64>()
            / period_f;
        let std_dev = variance.sqrt();

        let i = offset + warmup;
        middle[i] = Some(mean);
        upper[i] = Some(mean + std_dev_multiplier * std_dev);
        lower[i] = Some(mean - std_dev_multiplier * std_dev);
    }

    Ok(BollingerBands {
        upper,
        middle,
        lower,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a band value and checks it against `expected` with a tight tolerance.
    fn assert_close(actual: Option<f64>, expected: f64) {
        let value = actual.expect("band value should be present");
        assert!(
            (value - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            value
        );
    }

    #[test]
    fn test_bollinger_bands_basic() {
        let data: Vec<f64> = (1..=30).map(|x| x as f64).collect();
        let result = bollinger_bands(&data, 20, 2.0).unwrap();
        assert_eq!(result.upper.len(), 30);
        assert_eq!(result.middle.len(), 30);
        assert_eq!(result.lower.len(), 30);

        // Nothing is emitted until the first full 20-sample window ends at index 19.
        assert!(result.upper[..19].iter().all(Option::is_none));
        assert!(result.middle[..19].iter().all(Option::is_none));
        assert!(result.lower[..19].iter().all(Option::is_none));

        // Any run of 20 consecutive integers has population variance (20^2 - 1) / 12.
        let sigma = (399.0_f64 / 12.0).sqrt();
        for i in 19..30 {
            let upper = result.upper[i].unwrap();
            let middle = result.middle[i].unwrap();
            let lower = result.lower[i].unwrap();
            assert!(
                upper > middle,
                "Upper ({}) should be > middle ({}) at index {}",
                upper,
                middle,
                i
            );
            assert!(
                middle > lower,
                "Middle ({}) should be > lower ({}) at index {}",
                middle,
                lower,
                i
            );

            // The window holds the values (i - 18)..=(i + 1), so its mean is i - 8.5.
            let mean = i as f64 - 8.5;
            assert_close(result.middle[i], mean);
            assert_close(result.upper[i], mean + 2.0 * sigma);
            assert_close(result.lower[i], mean - 2.0 * sigma);
        }
    }

    #[test]
    fn test_bollinger_bands_exact_period_length() {
        // Exactly `period` samples produce a single value at the last index.
        let data = vec![1.0, 2.0, 3.0];
        let result = bollinger_bands(&data, 3, 1.0).unwrap();
        assert_eq!(result.middle, vec![None, None, Some(2.0)]);
        let sigma = (2.0_f64 / 3.0).sqrt();
        assert_close(result.upper[2], 2.0 + sigma);
        assert_close(result.lower[2], 2.0 - sigma);
    }

    #[test]
    fn test_bollinger_bands_constant_price() {
        let data = vec![50.0; 30];
        let result = bollinger_bands(&data, 20, 2.0).unwrap();
        for i in 19..30 {
            let upper = result.upper[i].unwrap();
            let middle = result.middle[i].unwrap();
            let lower = result.lower[i].unwrap();
            assert!((upper - middle).abs() < 0.0001);
            assert!((middle - lower).abs() < 0.0001);
            assert!((middle - 50.0).abs() < 0.0001);
        }
    }

    #[test]
    fn test_bollinger_bands_insufficient_data() {
        let data = vec![1.0, 2.0, 3.0];
        let result = bollinger_bands(&data, 20, 2.0);
        assert!(matches!(
            result,
            Err(IndicatorError::InsufficientData { need: 20, got: 3 })
        ));
    }

    #[test]
    fn test_bollinger_bands_zero_period() {
        let data = vec![1.0, 2.0, 3.0];
        let result = bollinger_bands(&data, 0, 2.0);
        assert!(matches!(result, Err(IndicatorError::InvalidPeriod(_))));
    }

    #[test]
    fn test_bollinger_bands_volatility() {
        let low_vol_data: Vec<f64> = (1..=30).map(|x| 50.0 + (x % 2) as f64).collect();
        let high_vol_data: Vec<f64> = (1..=30).map(|x| 50.0 + (x % 10) as f64 * 5.0).collect();
        let low_vol_result = bollinger_bands(&low_vol_data, 20, 2.0).unwrap();
        let high_vol_result = bollinger_bands(&high_vol_data, 20, 2.0).unwrap();
        let low_vol_width = low_vol_result.upper[29].unwrap() - low_vol_result.lower[29].unwrap();
        let high_vol_width =
            high_vol_result.upper[29].unwrap() - high_vol_result.lower[29].unwrap();
        assert!(
            high_vol_width > low_vol_width,
            "High volatility bands ({}) should be wider than low volatility bands ({})",
            high_vol_width,
            low_vol_width
        );
    }
}