use crate::bar_indicators::average::lr::LinearRegressionMA;
use crate::bar_indicators::indicator_value::IndicatorValue;
use crate::bar_indicators::ohlcv_field::OhlcvField;
use serde::{Serialize, Deserialize};
/// Selects the denominator used when the summed squared deviations are turned
/// into the channel's standard deviation.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum StandardDeviationMode {
    /// Divide by the number of buffered prices.
    Simple,
    /// Divide by the number of buffered prices minus one.
    ///
    /// NOTE(review): dividing by `n - 1` is what statistics texts call the
    /// *sample* (Bessel-corrected) deviation, not the population one. The
    /// variant name is kept because it is part of the serialized interface.
    Population,
    /// Divide by the number of buffered prices scaled by the indicator's
    /// current volatility factor.
    Adaptive,
}
/// Price-source tag carried by [`StandardDeviationChannels`].
///
/// In the code visible in this file the tag is stored at construction time
/// and handed back by `get_params`; the price fed into the regression is
/// extracted through the separate `OhlcvField` selector instead.
///
/// NOTE(review): the variant names presumably follow the usual price-type
/// conventions (close, typical, median, weighted close, OHLC/4), but no
/// formula for them is defined here — confirm against the consumers.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum RegressionSource {
    Close,
    Typical,
    Median,
    Weighted,
    Ohlc4,
}
/// Classification of a price relative to the channel, as produced by
/// `StandardDeviationChannels::generate_signal`.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum StandardDeviationSignal {
    /// Price is above the second upper band.
    StrongBreakoutUp,
    /// Price is above the first upper band but not above the second.
    BreakoutUp,
    /// Price is inside the first band pair and closer than half a standard
    /// deviation to the regression line.
    MeanReversion,
    /// Price is below the first lower band but not below the second.
    BreakoutDown,
    /// Price is below the second lower band.
    StrongBreakoutDown,
    /// None of the above; also reported while the indicator is not ready.
    WithinBands,
}
/// Channel indicator built from a linear-regression centre line and three
/// pairs of bands placed at multiples of a standard deviation around it.
#[derive(Debug, Clone)]
pub struct StandardDeviationChannels {
    /// Window length in bars; the constructors require it to be in `2..=512`.
    period: usize,
    /// Multiplier applied to the deviation for the second band pair.
    std_multiplier: f64,
    /// Chooses the denominator used in the deviation calculation.
    mode: StandardDeviationMode,
    /// Source tag; stored and returned by `get_params`, not read anywhere
    /// else in this file.
    source: RegressionSource,
    /// Selector used to extract the working price from each bar.
    ohlcv_source: OhlcvField,
    /// Linear regression fed with the extracted price; supplies the centre line.
    regression_ma: LinearRegressionMA,
    /// Ring buffer holding up to `period` most recent prices.
    price_buffer: Vec<f64>,
    /// Position in `price_buffer` that the next price will be written to.
    price_index: usize,
    /// Set once `price_buffer` has collected `period` prices.
    buffer_filled: bool,
    /// Most recently computed deviation (0.0 before the first calculation).
    std_deviation: f64,
    /// Centre line plus one effective deviation.
    upper_band_1: f64,
    /// Centre line plus `std_multiplier` effective deviations.
    upper_band_2: f64,
    /// Centre line plus three effective deviations.
    upper_band_3: f64,
    /// Centre line minus one effective deviation.
    lower_band_1: f64,
    /// Centre line minus `std_multiplier` effective deviations.
    lower_band_2: f64,
    /// Centre line minus three effective deviations.
    lower_band_3: f64,
    /// Scales the deviation denominator in adaptive mode; starts at 1.0 and
    /// is kept within `[0.5, 2.0]` when updated.
    volatility_factor: f64,
    /// Scales the deviation when the bands are built; starts at 1.0 and is
    /// kept within `[0.8, 1.5]` when updated.
    adaptive_multiplier: f64,
    /// Bars processed since construction or the last reset; written but not
    /// read in the code visible in this file.
    bar_count: usize,
}
impl StandardDeviationChannels {
/// Creates the indicator for `period` bars with the default configuration:
/// a band multiplier of 2.0, [`StandardDeviationMode::Simple`] and
/// [`RegressionSource::Close`].
///
/// # Panics
///
/// Panics under the same conditions as `new_custom` (period outside
/// `2..=512`).
pub fn new(period: usize) -> Self {
    const DEFAULT_MULTIPLIER: f64 = 2.0;
    Self::new_custom(
        period,
        DEFAULT_MULTIPLIER,
        StandardDeviationMode::Simple,
        RegressionSource::Close,
    )
}
pub fn new_custom(
period: usize,
std_multiplier: f64,
mode: StandardDeviationMode,
source: RegressionSource
) -> Self {
assert!(period > 1 && period <= 512, "Period must be between 2 and 512");
assert!(std_multiplier > 0.0, "Standard deviation multiplier must be positive");
Self {
period,
std_multiplier,
mode,
source,
ohlcv_source: OhlcvField::Close,
regression_ma: LinearRegressionMA::new(period), price_buffer: Vec::with_capacity(period),
price_index: 0,
buffer_filled: false,
std_deviation: 0.0,
upper_band_1: 0.0,
upper_band_2: 0.0,
upper_band_3: 0.0,
lower_band_1: 0.0,
lower_band_2: 0.0,
lower_band_3: 0.0,
volatility_factor: 1.0,
adaptive_multiplier: 1.0,
bar_count: 0,
}
}
/// Creates the indicator with every option spelled out, including the
/// `OhlcvField` selector used to pick the working price from each bar.
///
/// All bands and the deviation start at 0.0, both adaptive factors start at
/// 1.0, and the price buffer starts empty.
///
/// # Panics
///
/// Panics if `period` is not in `2..=512` or if `std_multiplier` is not
/// strictly positive.
pub fn with_source(
    period: usize,
    std_multiplier: f64,
    mode: StandardDeviationMode,
    source: RegressionSource,
    ohlcv_source: OhlcvField,
) -> Self {
    assert!(
        (2..=512).contains(&period),
        "Period must be between 2 and 512"
    );
    assert!(
        std_multiplier > 0.0,
        "Standard deviation multiplier must be positive"
    );

    let regression_ma = LinearRegressionMA::new(period);
    let price_buffer = Vec::<f64>::with_capacity(period);

    Self {
        // Configuration.
        period,
        std_multiplier,
        mode,
        source,
        ohlcv_source,
        // Rolling state.
        regression_ma,
        price_buffer,
        price_index: 0,
        buffer_filled: false,
        // Outputs.
        std_deviation: 0.0,
        upper_band_1: 0.0,
        upper_band_2: 0.0,
        upper_band_3: 0.0,
        lower_band_1: 0.0,
        lower_band_2: 0.0,
        lower_band_3: 0.0,
        // Adaptive scaling starts neutral.
        volatility_factor: 1.0,
        adaptive_multiplier: 1.0,
        bar_count: 0,
    }
}
/// Creates the indicator in [`StandardDeviationMode::Simple`] with the
/// [`RegressionSource::Close`] tag and a caller-chosen band multiplier.
///
/// # Panics
///
/// Panics under the same conditions as `new_custom`.
pub fn new_simple(period: usize, std_multiplier: f64) -> Self {
    let mode = StandardDeviationMode::Simple;
    let source = RegressionSource::Close;
    Self::new_custom(period, std_multiplier, mode, source)
}
/// Creates the indicator in [`StandardDeviationMode::Adaptive`] with a band
/// multiplier of 2.0 and the [`RegressionSource::Typical`] tag.
///
/// NOTE(review): the `Typical` tag is only stored; `new_custom` leaves the
/// price selector at `OhlcvField::Close`, and that selector is what the
/// update path uses. Confirm this is intended.
///
/// # Panics
///
/// Panics under the same conditions as `new_custom`.
pub fn new_adaptive(period: usize) -> Self {
    let mode = StandardDeviationMode::Adaptive;
    let source = RegressionSource::Typical;
    Self::new_custom(period, 2.0, mode, source)
}
/// Feeds one OHLCV bar into the indicator.
///
/// Returns `(upper_band_2, regression_value, lower_band_2)`, where
/// `regression_value` is what the inner regression returned for this bar.
/// Until both the price buffer and the regression are ready the bands are
/// left untouched, so they stay at their initial 0.0.
pub fn update_bar(&mut self, open: f64, high: f64, low: f64, close: f64, volume: f64) -> (f64, f64, f64) {
    self.bar_count += 1;

    // The regression is driven by the single extracted price, passed in
    // every OHLC slot.
    let sample = self.get_price_by_source(open, high, low, close, volume);
    let midline = self
        .regression_ma
        .update_bar(sample, sample, sample, sample, volume);
    self.update_price_buffer(sample);

    let warmed_up = self.buffer_filled && self.regression_ma.is_ready();
    if warmed_up {
        // Order matters: the deviation is computed before the adaptive
        // factors are refreshed, so it uses the factor from the previous bar.
        self.calculate_standard_deviation();
        self.update_adaptive_parameters(high, low);
        self.calculate_bands(midline);
    }

    (self.upper_band_2, midline, self.lower_band_2)
}
/// Extracts the working price from a bar using the configured `OhlcvField`
/// selector. The `RegressionSource` tag is not consulted here.
fn get_price_by_source(&self, open: f64, high: f64, low: f64, close: f64, volume: f64) -> f64 {
    self.ohlcv_source
        .extract(open, high, low, close, volume)
}
/// Stores `price` in the ring buffer: appended while the buffer is still
/// growing, overwriting the slot at `price_index` once it is full. The write
/// index then advances modulo `period`, and `buffer_filled` is latched the
/// first time the buffer reaches `period` entries.
fn update_price_buffer(&mut self, price: f64) {
    match self.buffer_filled {
        true => self.price_buffer[self.price_index] = price,
        false => self.price_buffer.push(price),
    }

    let next_slot = self.price_index + 1;
    self.price_index = next_slot % self.period;

    if !self.buffer_filled && self.price_buffer.len() == self.period {
        self.buffer_filled = true;
    }
}
/// Recomputes `std_deviation` from the buffered prices.
///
/// Each buffered price is compared with the *current* regression value (a
/// single reference level, not a per-bar point on the regression line). The
/// summed squared differences are divided by a mode-dependent denominator and
/// the square root of the quotient is stored.
fn calculate_standard_deviation(&mut self) {
    let sample_count = if self.buffer_filled {
        self.period
    } else {
        self.price_buffer.len()
    };

    // The reference level is the same for every buffered price, so it is
    // read once instead of once per element.
    let center = self.regression_ma.value().main();

    let squared_deviation_sum: f64 = self
        .price_buffer
        .iter()
        .take(sample_count)
        .map(|&price| {
            let deviation = price - center;
            deviation * deviation
        })
        .sum();

    let denominator = match self.mode {
        StandardDeviationMode::Simple => sample_count as f64,
        // NOTE(review): `n - 1` is the sample (Bessel-corrected) form even
        // though the variant is called `Population`; behavior is kept as is.
        StandardDeviationMode::Population => (sample_count - 1) as f64,
        // Uses whatever volatility factor is currently stored.
        StandardDeviationMode::Adaptive => sample_count as f64 * self.volatility_factor,
    };

    self.std_deviation = (squared_deviation_sum / denominator).sqrt();
}
/// Refreshes the two adaptive scaling factors from the bar's high/low range.
///
/// Does nothing unless the mode is `Adaptive`, and leaves both factors
/// unchanged when the bar's midpoint is not strictly positive. Otherwise the
/// range relative to the midpoint drives `volatility_factor` (clamped to
/// `[0.5, 2.0]`) and `adaptive_multiplier` (clamped to `[0.8, 1.5]`).
fn update_adaptive_parameters(&mut self, high: f64, low: f64) {
    if self.mode != StandardDeviationMode::Adaptive {
        return;
    }

    let bar_range = high - low;
    let bar_midpoint = (high + low) / 2.0;

    // Written as `> 0.0` so a NaN midpoint also skips the update.
    if bar_midpoint > 0.0 {
        let relative_range = bar_range / bar_midpoint;
        self.volatility_factor = (1.0 + relative_range * 10.0).clamp(0.5, 2.0);
        self.adaptive_multiplier = (1.0 + relative_range * 2.0).clamp(0.8, 1.5);
    }
}
/// Places the three band pairs around `regression_value`.
///
/// The effective deviation is `std_deviation * adaptive_multiplier`; band
/// pair 1 sits one effective deviation away, pair 2 sits `std_multiplier`
/// away, and pair 3 sits three away.
fn calculate_bands(&mut self, regression_value: f64) {
    let sigma = self.std_deviation * self.adaptive_multiplier;

    let offset_1 = 1.0 * sigma;
    let offset_2 = self.std_multiplier * sigma;
    let offset_3 = 3.0 * sigma;

    self.upper_band_1 = regression_value + offset_1;
    self.upper_band_2 = regression_value + offset_2;
    self.upper_band_3 = regression_value + offset_3;

    self.lower_band_1 = regression_value - offset_1;
    self.lower_band_2 = regression_value - offset_2;
    self.lower_band_3 = regression_value - offset_3;
}
/// Returns the second band pair and the regression centre line packaged as
/// an `IndicatorValue::Channel3`.
pub fn value(&self) -> IndicatorValue {
    let (upper, middle, lower) = self.value_tuple();
    IndicatorValue::Channel3 {
        upper,
        middle,
        lower,
    }
}
/// Returns `(upper_band_2, regression centre line, lower_band_2)`.
pub fn value_tuple(&self) -> (f64, f64, f64) {
    let center = self.regression_ma.value().main();
    (self.upper_band_2, center, self.lower_band_2)
}
/// Returns every level in the order
/// `(upper 3, upper 2, upper 1, centre line, lower 1, lower 2, lower 3)`.
pub fn all_bands(&self) -> (f64, f64, f64, f64, f64, f64, f64) {
    let center = self.regression_ma.value().main();
    let (up_3, up_2, up_1) = (self.upper_band_3, self.upper_band_2, self.upper_band_1);
    let (down_1, down_2, down_3) = (self.lower_band_1, self.lower_band_2, self.lower_band_3);
    (up_3, up_2, up_1, center, down_1, down_2, down_3)
}
/// Returns the main value currently reported by the inner regression.
pub fn regression_line(&self) -> f64 {
    let current = self.regression_ma.value();
    current.main()
}
/// Returns the distance between the second upper and second lower band.
pub fn channel_width(&self) -> f64 {
    let (top, bottom) = (self.upper_band_2, self.lower_band_2);
    top - bottom
}
/// Returns where `price` sits inside the second band pair: 0.0 at the lower
/// band, 1.0 at the upper band, and values outside that range when the price
/// is outside the channel. Falls back to 0.5 when the channel width is not
/// strictly positive.
pub fn position_in_channel(&self, price: f64) -> f64 {
    let width = self.channel_width();
    if width > 0.0 {
        return (price - self.lower_band_2) / width;
    }
    0.5
}
/// Returns `(slope, intercept, r2, std_deviation)`, the first three read from
/// the inner regression and the last from this indicator.
pub fn regression_stats(&self) -> (f64, f64, f64, f64) {
    let fit = &self.regression_ma;
    (fit.slope(), fit.intercept(), fit.r2(), self.std_deviation)
}
pub fn standard_deviation(&self) -> f64 {
self.std_deviation
}
/// Classifies `price` against the channel.
///
/// The checks run in this order and the first match wins:
/// above upper band 2, above upper band 1, below lower band 2, below lower
/// band 1. A price inside the first band pair is reported as `MeanReversion`
/// when it is closer than half a standard deviation to the regression line,
/// otherwise as `WithinBands`. `WithinBands` is also returned while the
/// indicator is not ready.
///
/// The half-deviation test uses the raw `std_deviation`, whereas the bands
/// themselves are built from the deviation scaled by the adaptive multiplier.
pub fn generate_signal(&self, price: f64) -> StandardDeviationSignal {
    if !self.is_ready() {
        return StandardDeviationSignal::WithinBands;
    }

    if price > self.upper_band_2 {
        return StandardDeviationSignal::StrongBreakoutUp;
    }
    if price > self.upper_band_1 {
        return StandardDeviationSignal::BreakoutUp;
    }
    if price < self.lower_band_2 {
        return StandardDeviationSignal::StrongBreakoutDown;
    }
    if price < self.lower_band_1 {
        return StandardDeviationSignal::BreakoutDown;
    }

    let center = self.regression_ma.value().main();
    let distance = (price - center).abs();

    // Guard the division explicitly: with a zero deviation (perfectly flat
    // input) the ratio would be 0/0. Such a price is reported as
    // `WithinBands` rather than relying on NaN comparison semantics.
    let near_center = self.std_deviation > 0.0 && distance / self.std_deviation < 0.5;

    if near_center {
        StandardDeviationSignal::MeanReversion
    } else {
        StandardDeviationSignal::WithinBands
    }
}
/// Tests `price` against a caller-chosen sigma level around the regression
/// line, using the raw `std_deviation`.
///
/// Returns `Some(true)` when the price is above `center + sigma_level * std`,
/// `Some(false)` when it is below `center - sigma_level * std`, and `None`
/// when it is between the two or the indicator is not ready.
pub fn is_breakout(&self, price: f64, sigma_level: f64) -> Option<bool> {
    if !self.is_ready() {
        return None;
    }

    let center = self.regression_ma.value().main();
    let upper_threshold = center + sigma_level * self.std_deviation;

    if price > upper_threshold {
        return Some(true);
    }

    let lower_threshold = center - sigma_level * self.std_deviation;
    if price < lower_threshold {
        Some(false)
    } else {
        None
    }
}
/// Reports whether the price has moved closer to the regression line than
/// `prev_price` was and is now within one raw standard deviation of it.
/// Always `false` while the indicator is not ready.
pub fn is_mean_reversion_signal(&self, price: f64, prev_price: f64) -> bool {
    if !self.is_ready() {
        return false;
    }

    let center = self.regression_ma.value().main();
    let distance_from_center = |level: f64| (level - center).abs();

    let before = distance_from_center(prev_price);
    let now = distance_from_center(price);

    let moving_closer = now < before;
    let inside_one_sigma = now < self.std_deviation;
    moving_closer && inside_one_sigma
}
/// Maps the regression slope to a direction: `1` when the slope exceeds
/// 0.001, `-1` when it is below -0.001, and `0` otherwise.
pub fn trend_direction(&self) -> i8 {
    const FLAT_BAND: f64 = 0.001;

    match self.regression_ma.slope() {
        slope if slope > FLAT_BAND => 1,
        slope if slope < -FLAT_BAND => -1,
        _ => 0,
    }
}
/// Returns the `r2` value reported by the inner regression.
pub fn trend_strength(&self) -> f64 {
    let fit = &self.regression_ma;
    fit.r2()
}
/// Returns `true` once the inner regression reports ready and the price
/// buffer has collected `period` prices.
pub fn is_ready(&self) -> bool {
    let regression_ready = self.regression_ma.is_ready();
    regression_ready && self.buffer_filled
}
pub fn get_params(&self) -> (usize, f64, StandardDeviationMode, RegressionSource) {
(self.period, self.std_multiplier, self.mode, self.source)
}
/// Returns the indicator to its freshly constructed state while keeping the
/// configuration (period, multiplier, mode, source tag, price selector).
pub fn reset(&mut self) {
    // Rolling inputs.
    self.regression_ma.reset();
    self.price_buffer.clear();
    self.price_index = 0;
    self.buffer_filled = false;
    self.bar_count = 0;

    // Computed outputs.
    self.std_deviation = 0.0;
    self.upper_band_1 = 0.0;
    self.lower_band_1 = 0.0;
    self.upper_band_2 = 0.0;
    self.lower_band_2 = 0.0;
    self.upper_band_3 = 0.0;
    self.lower_band_3 = 0.0;

    // Adaptive scaling back to neutral.
    self.volatility_factor = 1.0;
    self.adaptive_multiplier = 1.0;
}
}
impl Default for StandardDeviationChannels {
    /// Builds the indicator with a 20-bar period and the defaults of
    /// `StandardDeviationChannels::new`.
    fn default() -> Self {
        const DEFAULT_PERIOD: usize = 20;
        Self::new(DEFAULT_PERIOD)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds one bar whose open and close equal `price`, with the high and
    /// low placed `spread` above and below it and a volume of 1000.
    fn push_symmetric_bar(sdc: &mut StandardDeviationChannels, price: f64, spread: f64) {
        sdc.update_bar(price, price + spread, price - spread, price, 1000.0);
    }

    /// A new indicator is not ready and has a zero-width channel.
    #[test]
    fn test_standard_deviation_channels_creation() {
        let sdc = StandardDeviationChannels::new(20);

        assert!(!sdc.is_ready());
        assert_eq!(sdc.channel_width(), 0.0);
    }

    /// Feeding more bars than the period makes the indicator ready.
    #[test]
    fn test_standard_deviation_channels_warmup() {
        let mut sdc = StandardDeviationChannels::new(20);

        for step in 0..25 {
            let price = 100.0 + (step as f64 * 0.1).sin() * 5.0;
            push_symmetric_bar(&mut sdc, price, 1.0);
        }

        assert!(sdc.is_ready());
    }

    /// The seven levels come back ordered from the outer upper band down to
    /// the outer lower band.
    #[test]
    fn test_standard_deviation_channels_bands() {
        let mut sdc = StandardDeviationChannels::new(20);

        for step in 0..25 {
            let price = 100.0 + step as f64;
            push_symmetric_bar(&mut sdc, price, 1.0);
        }

        let (u3, u2, u1, mid, l1, l2, l3) = sdc.all_bands();
        assert!(u3 >= u2);
        assert!(u2 >= u1);
        assert!(u1 >= mid);
        assert!(mid >= l1);
        assert!(l1 >= l2);
        assert!(l2 >= l3);
    }

    /// The adaptive variant becomes ready and reports a positive deviation on
    /// oscillating input.
    #[test]
    fn test_standard_deviation_channels_adaptive() {
        let mut sdc = StandardDeviationChannels::new_adaptive(20);

        for step in 0..30 {
            let price = 100.0 + (step as f64 * 0.2).sin() * 10.0;
            push_symmetric_bar(&mut sdc, price, 2.0);
        }

        assert!(sdc.is_ready());
        assert!(sdc.standard_deviation() > 0.0);
    }

    /// Resetting a warmed-up indicator returns it to the not-ready,
    /// zero-width state.
    #[test]
    fn test_standard_deviation_channels_reset() {
        let mut sdc = StandardDeviationChannels::new(20);

        for step in 0..25 {
            let price = 100.0 + step as f64;
            sdc.update_bar(price, 101.0, 99.0, price, 1000.0);
        }
        sdc.reset();

        assert!(!sdc.is_ready());
        assert_eq!(sdc.channel_width(), 0.0);
    }
}