use std::collections::VecDeque;
use crate::indicators::utils::{calculate_sma, standard_deviation, validate_data_length};
use crate::indicators::{validate_period, Candle, Indicator};
use crate::IndicatorError;
/// One Bollinger Bands sample, produced for a single full window of inputs.
///
/// For a window with mean `sma`, standard deviation `std_dev` (as returned by
/// `indicators::utils::standard_deviation`; whether that is the population or the
/// sample deviation is not visible here, TODO confirm) and multiplier `k`, the
/// indicator in this file fills the fields as documented below.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BollingerBandsResult {
/// Middle band: the simple moving average of the window.
pub middle: f64,
/// Upper band: `middle + k * std_dev`.
pub upper: f64,
/// Lower band: `middle - k * std_dev`.
pub lower: f64,
/// Relative band width: `(upper - lower) / middle`. Not guarded against a zero
/// `middle`, so it can be infinite or NaN in that case.
pub bandwidth: f64,
}
/// Bollinger Bands indicator over `f64` values or candle closing prices.
///
/// Supports batch evaluation (`calculate`) and streaming evaluation (`next`). After a
/// batch run the rolling buffer holds the final window, so streaming can continue.
#[derive(Debug)]
pub struct BollingerBands {
/// Window length; checked with `validate_period(period, 1)` in `new`.
period: usize,
/// Standard deviation multiplier applied to both bands; `new` rejects `k <= 0.0`.
k: f64,
/// Most recent inputs for streaming updates; never holds more than `period` values.
values: VecDeque<f64>,
/// Mean of `values` from the last completed computation, `None` after a reset.
/// NOTE(review): assigned in `calculate`/`next` but never read in this file; consider
/// removing it or exposing it through an accessor.
sma: Option<f64>,
}
impl BollingerBands {
    /// Creates a new indicator with window length `period` and band multiplier `k`.
    ///
    /// # Errors
    ///
    /// * Any error returned by `validate_period(period, 1)` (e.g. for `period == 0`).
    /// * `IndicatorError::InvalidParameter` if `k` is not a finite, strictly positive
    ///   number. NaN has to be rejected explicitly: `NaN <= 0.0` is `false`, so a plain
    ///   sign check would accept it and every computed band would be NaN. An infinite
    ///   `k` is rejected for the same reason.
    pub fn new(period: usize, k: f64) -> Result<Self, IndicatorError> {
        validate_period(period, 1)?;
        if !k.is_finite() || k <= 0.0 {
            return Err(IndicatorError::InvalidParameter(
                "Standard deviation multiplier must be a positive finite number".to_string(),
            ));
        }
        Ok(Self {
            period,
            k,
            values: VecDeque::with_capacity(period),
            sma: None,
        })
    }

    /// Arithmetic mean of the buffered values.
    ///
    /// Returns NaN (`0.0 / 0.0`) when the buffer is empty. Callers in this file only
    /// invoke it after values have been pushed into the buffer.
    fn calculate_sma(&self) -> f64 {
        self.values.iter().sum::<f64>() / self.values.len() as f64
    }

    /// Clears the rolling buffer and the cached mean, so the next streamed value
    /// starts a fresh window.
    pub fn reset_state(&mut self) {
        self.values.clear();
        self.sma = None;
    }
}
/// Assembles one band sample from a window mean, its standard deviation and the
/// multiplier `k`.
///
/// `bandwidth` is `(upper - lower) / middle`. It is not guarded against a zero
/// `middle` and is then infinite or NaN.
fn bands_from_stats(middle: f64, std_dev: f64, k: f64) -> BollingerBandsResult {
    let upper = middle + (k * std_dev);
    let lower = middle - (k * std_dev);
    BollingerBandsResult {
        middle,
        upper,
        lower,
        bandwidth: (upper - lower) / middle,
    }
}

impl Indicator<f64, BollingerBandsResult> for BollingerBands {
    /// Computes the bands for every complete window of `period` consecutive values.
    ///
    /// The streaming state is reset first. It is then primed with the last `period`
    /// values of `data`, so a following `next` call continues from the end of `data`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `validate_data_length` (presumably raised when
    /// `data.len() < period`, TODO confirm), `calculate_sma` and `standard_deviation`.
    fn calculate(&mut self, data: &[f64]) -> Result<Vec<BollingerBandsResult>, IndicatorError> {
        validate_data_length(data, self.period)?;
        self.reset_state();
        let sma_values = calculate_sma(data, self.period)?;
        let mut result = Vec::with_capacity(sma_values.len());
        // `windows` yields exactly the slices `data[i..i + period]`, each paired with the
        // SMA of the same window. No manual index arithmetic or panicking slice is needed.
        for (window, &sma) in data.windows(self.period).zip(sma_values.iter()) {
            let std_dev = standard_deviation(window, Some(sma))?;
            result.push(bands_from_stats(sma, std_dev, self.k));
        }
        // Leave the rolling buffer holding the final window so streaming can resume.
        let tail_start = data.len() - self.period;
        self.values.extend(&data[tail_start..]);
        self.sma = Some(self.calculate_sma());
        Ok(result)
    }

    /// Pushes `value` into the rolling window and returns the bands once it is full.
    ///
    /// Returns `Ok(None)` while fewer than `period` values have been seen since the
    /// last reset.
    ///
    /// # Errors
    ///
    /// Propagates errors from `standard_deviation`. In that case the cached mean is
    /// left unchanged, but `value` has already been added to the window.
    fn next(&mut self, value: f64) -> Result<Option<BollingerBandsResult>, IndicatorError> {
        self.values.push_back(value);
        if self.values.len() > self.period {
            self.values.pop_front();
        }
        // The buffer never exceeds `period`, so this is the "window not yet full" case.
        if self.values.len() < self.period {
            return Ok(None);
        }
        let sma = self.calculate_sma();
        // `make_contiguous` exposes the deque as one slice in logical order, avoiding
        // the per-tick `Vec` allocation of collecting it.
        let window: &[f64] = self.values.make_contiguous();
        let std_dev = standard_deviation(window, Some(sma))?;
        self.sma = Some(sma);
        Ok(Some(bands_from_stats(sma, std_dev, self.k)))
    }

    /// Clears all streaming state (see `BollingerBands::reset_state`).
    fn reset(&mut self) {
        self.reset_state();
    }
}
impl Indicator<Candle, BollingerBandsResult> for BollingerBands {
fn calculate(&mut self, data: &[Candle]) -> Result<Vec<BollingerBandsResult>, IndicatorError> {
validate_data_length(data, self.period)?;
let close_prices: Vec<f64> = data.iter().map(|candle| candle.close).collect();
self.calculate(&close_prices)
}
fn next(&mut self, candle: Candle) -> Result<Option<BollingerBandsResult>, IndicatorError> {
let close_price = candle.close;
self.next(close_price)
}
fn reset(&mut self) {
self.reset_state();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a candle around `$close`: open and low sit 0.5 below it, high 0.5 above,
    // and volume is fixed at 1000.
    macro_rules! candle {
        ($timestamp:expr, $close:expr) => {
            Candle {
                timestamp: $timestamp,
                open: $close - 0.5,
                high: $close + 0.5,
                low: $close - 0.5,
                close: $close,
                volume: 1000.0,
            }
        };
    }

    const PRICES: [f64; 5] = [5.0, 7.0, 9.0, 11.0, 13.0];
    const TIGHT: f64 = 0.000001;

    // Fails the test unless `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {} +/- {}, got {}",
            expected,
            tolerance,
            actual
        );
    }

    // Fails the test unless every field of the two results agrees to within `TIGHT`.
    fn assert_same_bands(lhs: &BollingerBandsResult, rhs: &BollingerBandsResult) {
        assert_close(lhs.middle, rhs.middle, TIGHT);
        assert_close(lhs.upper, rhs.upper, TIGHT);
        assert_close(lhs.lower, rhs.lower, TIGHT);
        assert_close(lhs.bandwidth, rhs.bandwidth, TIGHT);
    }

    // Loose expectations for a window of 5, 7, 9 with k = 2.
    fn assert_first_window(bands: &BollingerBandsResult) {
        assert_close(bands.middle, 7.0, 0.1);
        assert_close(bands.upper, 11.0, 2.0);
        assert_close(bands.lower, 3.0, 2.0);
    }

    #[test]
    fn test_bollinger_bands_new() {
        assert!(BollingerBands::new(20, 2.0).is_ok());
        assert!(BollingerBands::new(0, 2.0).is_err());
        assert!(BollingerBands::new(20, -1.0).is_err());
    }

    #[test]
    fn test_bollinger_bands_calculation() {
        let mut bb = BollingerBands::new(3, 2.0).unwrap();
        let result = bb.calculate(&PRICES[..]).unwrap();
        assert_eq!(result.len(), 3);
        assert_first_window(&result[0]);
        let second = &result[1];
        assert_close(second.middle, 9.0, 0.1);
        assert!(second.upper > second.middle);
        assert!(second.lower < second.middle);
    }

    #[test]
    fn test_bollinger_bands_next() {
        let mut bb = BollingerBands::new(3, 2.0).unwrap();
        assert_eq!(bb.next(5.0).unwrap(), None);
        assert_eq!(bb.next(7.0).unwrap(), None);
        let bands = bb
            .next(9.0)
            .unwrap()
            .expect("third value completes the window");
        assert_first_window(&bands);
    }

    #[test]
    fn test_bollinger_bands_reset() {
        let mut bb = BollingerBands::new(3, 2.0).unwrap();
        for price in [5.0_f64, 7.0, 9.0] {
            bb.next(price).unwrap();
        }
        bb.reset_state();
        assert_eq!(bb.next(11.0).unwrap(), None);
    }

    #[test]
    fn test_bollinger_bands_calculation_with_candles() {
        let candles = vec![
            candle!(1, 5.0),
            candle!(2, 7.0),
            candle!(3, 9.0),
            candle!(4, 11.0),
            candle!(5, 13.0),
        ];
        let mut bb = BollingerBands::new(3, 2.0).unwrap();
        let from_candles = bb.calculate(&candles).unwrap();
        assert_eq!(from_candles.len(), 3);
        assert_first_window(&from_candles[0]);

        let mut bb_prices = BollingerBands::new(3, 2.0).unwrap();
        let from_prices = bb_prices.calculate(&PRICES[..]).unwrap();
        for (candle_bands, price_bands) in from_candles.iter().zip(from_prices.iter()) {
            assert_same_bands(candle_bands, price_bands);
        }
    }

    #[test]
    fn test_bollinger_bands_next_with_candles() {
        let mut bb = BollingerBands::new(3, 2.0).unwrap();
        assert_eq!(bb.next(candle!(1, 5.0)).unwrap(), None);
        assert_eq!(bb.next(candle!(2, 7.0)).unwrap(), None);
        let bands = bb
            .next(candle!(3, 9.0))
            .unwrap()
            .expect("third candle completes the window");
        assert_first_window(&bands);

        let mut bb_prices = BollingerBands::new(3, 2.0).unwrap();
        bb_prices.next(5.0).unwrap();
        bb_prices.next(7.0).unwrap();
        let price_bands = bb_prices.next(9.0).unwrap().unwrap();
        assert_same_bands(&bands, &price_bands);
    }

    #[test]
    fn test_bollinger_bands_reset_with_candles() {
        let mut bb = BollingerBands::new(3, 2.0).unwrap();
        for (timestamp, close) in [(1, 5.0), (2, 7.0), (3, 9.0)] {
            bb.next(candle!(timestamp, close)).unwrap();
        }
        bb.reset_state();
        assert_eq!(bb.next(candle!(4, 11.0)).unwrap(), None);
    }
}