use crate::bar_indicators::average::lr::LinearRegressionMA;
use crate::bar_indicators::indicator_value::IndicatorValue;
use crate::bar_indicators::ohlcv_field::OhlcvField;
use serde::{Serialize, Deserialize};
/// How the upper and lower bands of a [`RegressionChannels`] are sized around
/// the regression midline.
#[derive(Debug, Clone, Copy, PartialEq, Default, Serialize, Deserialize)]
pub enum RegressionChannelMode {
    /// Bands sit `std_dev_mult` standard deviations of the residuals away from
    /// the midline.
    #[default]
    Standard,
    /// Bands sit a fixed percentage (`std_dev_mult` interpreted as percent) of
    /// the midline value away from it.
    Percentage,
    /// Like `Standard`, but the residual deviation is scaled up when the fit
    /// quality (R²) is low and down when it is high.
    R2Weighted,
}
/// Streaming linear-regression channel indicator.
///
/// Every bar updates an internal [`LinearRegressionMA`], whose output is the
/// channel midline. It also updates a fixed-size ring buffer of source prices,
/// which is used to measure how far prices scatter around the fitted line.
/// The upper and lower bands are then sized according to the configured
/// [`RegressionChannelMode`].
#[derive(Debug, Clone)]
pub struct RegressionChannels {
    /// Number of bars in the regression window.
    period: usize,
    /// Band multiplier; its interpretation depends on `mode`.
    std_dev_mult: f64,
    /// Band sizing strategy.
    mode: RegressionChannelMode,
    /// Which OHLCV field is recorded into `price_buffer`.
    source: OhlcvField,
    /// Regression engine that provides the midline, slope, intercept and R².
    lr: LinearRegressionMA,
    /// Ring buffer holding the most recent `period` source prices.
    price_buffer: Vec<f64>,
    /// Scratch storage reused for the residuals of the latest fit.
    residuals_buffer: Vec<f64>,
    /// Slot in `price_buffer` that the next price overwrites once the buffer is full.
    buffer_index: usize,
    /// Set once `price_buffer` holds `period` prices.
    buffer_filled: bool,
    /// Current upper band.
    upper: f64,
    /// Current regression (midline) value.
    middle: f64,
    /// Current lower band.
    lower: f64,
    /// Slope reported by `lr` for the latest window.
    slope: f64,
    /// Intercept reported by `lr` for the latest window.
    intercept: f64,
    /// Coefficient of determination reported by `lr`.
    r_squared: f64,
    /// Population standard deviation of the residuals around the fitted line.
    std_dev: f64,
    /// Trend classification derived from `slope`.
    trend_direction: TrendDirection,
}
/// Direction of the regression line, as classified by [`RegressionChannels`]
/// from the sign of its slope (with a small dead-band around zero).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum TrendDirection {
    /// Slope is above the positive dead-band threshold.
    Uptrend,
    /// Slope is below the negative dead-band threshold.
    Downtrend,
    /// Slope is within the dead-band, or the channel is not ready yet.
    Sideways,
}
impl RegressionChannels {
    /// Slopes whose magnitude is at or below this value are classified as
    /// [`TrendDirection::Sideways`].
    const TREND_SLOPE_EPSILON: f64 = 0.001;

    /// Creates a channel over `period` bars that records the close price.
    ///
    /// # Panics
    ///
    /// Panics if `period` is outside `2..=512`, or if `std_dev_mult` is not
    /// strictly positive (NaN included).
    pub fn new(period: usize, std_dev_mult: f64, mode: RegressionChannelMode) -> Self {
        Self::with_source(period, std_dev_mult, mode, OhlcvField::Close)
    }

    /// Creates a channel over `period` bars that records `source` prices for
    /// the residual calculation.
    ///
    /// # Panics
    ///
    /// Panics if `period` is outside `2..=512`, or if `std_dev_mult` is not
    /// strictly positive (NaN included).
    pub fn with_source(
        period: usize,
        std_dev_mult: f64,
        mode: RegressionChannelMode,
        source: OhlcvField,
    ) -> Self {
        assert!((2..=512).contains(&period), "Period must be between 2 and 512");
        assert!(std_dev_mult > 0.0, "Standard deviation multiplier must be positive");
        Self {
            period,
            std_dev_mult,
            mode,
            source,
            lr: LinearRegressionMA::new(period),
            price_buffer: Vec::with_capacity(period),
            residuals_buffer: Vec::with_capacity(period),
            buffer_index: 0,
            buffer_filled: false,
            upper: 0.0,
            middle: 0.0,
            lower: 0.0,
            slope: 0.0,
            intercept: 0.0,
            r_squared: 0.0,
            std_dev: 0.0,
            trend_direction: TrendDirection::Sideways,
        }
    }

    /// Standard-deviation channel: midline ± `std_dev_mult` · σ(residuals).
    pub fn new_standard(period: usize, std_dev_mult: f64) -> Self {
        Self::new(period, std_dev_mult, RegressionChannelMode::Standard)
    }

    /// Percentage channel: midline ± `percentage` % of the midline value.
    pub fn new_percentage(period: usize, percentage: f64) -> Self {
        Self::new(period, percentage, RegressionChannelMode::Percentage)
    }

    /// R²-weighted channel: the residual deviation widens as fit quality drops.
    pub fn new_r2_weighted(period: usize, base_mult: f64) -> Self {
        Self::new(period, base_mult, RegressionChannelMode::R2Weighted)
    }

    /// Feeds one bar and returns `(upper, middle, lower)`.
    ///
    /// Until [`is_ready`](Self::is_ready) is true, the bands and regression
    /// metrics are zeroed. `middle` still reflects whatever the internal
    /// regression returns during warm-up.
    pub fn update_bar(
        &mut self,
        open: f64,
        high: f64,
        low: f64,
        close: f64,
        volume: f64,
    ) -> (f64, f64, f64) {
        let price = self.source.extract(open, high, low, close, volume);
        // NOTE(review): `lr` is fed the whole bar, while residuals use `self.source`.
        // If `LinearRegressionMA` fits a different field than `source`, the
        // residuals are measured against a line fitted to other data — confirm.
        self.middle = self.lr.update_bar(open, high, low, close, volume);
        self.push_price(price);
        if self.is_ready() {
            self.update_regression_metrics();
            self.calculate_channels();
        } else {
            self.reset_channels();
        }
        self.value_tuple()
    }

    /// Appends `price` to the ring buffer, overwriting the oldest sample once full.
    fn push_price(&mut self, price: f64) {
        if self.buffer_filled {
            self.price_buffer[self.buffer_index] = price;
        } else {
            self.price_buffer.push(price);
            self.buffer_filled = self.price_buffer.len() == self.period;
        }
        // After this, `buffer_index` points at the oldest sample of a full buffer.
        self.buffer_index = (self.buffer_index + 1) % self.period;
    }

    /// Copies slope, intercept and R² from the regression, classifies the
    /// trend, and recomputes the residual deviation.
    fn update_regression_metrics(&mut self) {
        self.slope = self.lr.slope();
        self.intercept = self.lr.intercept();
        self.r_squared = self.lr.r2();
        self.trend_direction = if self.slope > Self::TREND_SLOPE_EPSILON {
            TrendDirection::Uptrend
        } else if self.slope < -Self::TREND_SLOPE_EPSILON {
            TrendDirection::Downtrend
        } else {
            TrendDirection::Sideways
        };
        self.calculate_residuals();
    }

    /// Recomputes `std_dev`, the population standard deviation of the residuals
    /// `price - (slope · x + intercept)`.
    ///
    /// The i-th oldest buffered price is paired with `x = i + 1` (the newest
    /// gets `x = period`).
    fn calculate_residuals(&mut self) {
        let len = self.price_buffer.len();
        // Once the ring buffer has wrapped, the oldest sample lives at
        // `buffer_index`, not at slot 0. Walk from there so each price is
        // paired with its chronological x position.
        let start = if self.buffer_filled { self.buffer_index } else { 0 };
        let (slope, intercept) = (self.slope, self.intercept);
        let prices = &self.price_buffer;
        self.residuals_buffer.clear();
        self.residuals_buffer.extend((0..len).map(|i| {
            let x = (i + 1) as f64;
            prices[(start + i) % len] - (slope * x + intercept)
        }));

        let count = self.residuals_buffer.len();
        if count > 0 {
            let n = count as f64;
            let mean = self.residuals_buffer.iter().sum::<f64>() / n;
            let variance = self
                .residuals_buffer
                .iter()
                .map(|&r| (r - mean).powi(2))
                .sum::<f64>()
                / n;
            self.std_dev = variance.sqrt();
        }
    }

    /// Dispatches band computation on the configured mode.
    fn calculate_channels(&mut self) {
        match self.mode {
            RegressionChannelMode::Standard => self.calculate_standard_channels(),
            RegressionChannelMode::Percentage => self.calculate_percentage_channels(),
            RegressionChannelMode::R2Weighted => self.calculate_r2_weighted_channels(),
        }
    }

    /// Bands at midline ± `std_dev_mult` · σ.
    fn calculate_standard_channels(&mut self) {
        self.upper = self.middle + self.std_dev_mult * self.std_dev;
        self.lower = self.middle - self.std_dev_mult * self.std_dev;
    }

    /// Bands at midline ± `std_dev_mult` % of |midline|.
    ///
    /// The absolute value keeps `upper >= lower` even for negative regression values.
    fn calculate_percentage_channels(&mut self) {
        let band = self.middle.abs() * (self.std_dev_mult / 100.0);
        self.upper = self.middle + band;
        self.lower = self.middle - band;
    }

    /// Bands at midline ± `std_dev_mult` · σ · (0.5 + w).
    ///
    /// `w = 1 - R²` for positive R², otherwise 1, so a poor fit widens the channel.
    fn calculate_r2_weighted_channels(&mut self) {
        let r2_weight = if self.r_squared > 0.0 {
            1.0 - self.r_squared
        } else {
            1.0
        };
        let weighted_std_dev = self.std_dev * (0.5 + r2_weight);
        self.upper = self.middle + self.std_dev_mult * weighted_std_dev;
        self.lower = self.middle - self.std_dev_mult * weighted_std_dev;
    }

    /// Zeroes bands and regression metrics (but not `middle`) during warm-up.
    fn reset_channels(&mut self) {
        self.upper = 0.0;
        self.lower = 0.0;
        self.slope = 0.0;
        self.intercept = 0.0;
        self.r_squared = 0.0;
        self.std_dev = 0.0;
        self.trend_direction = TrendDirection::Sideways;
    }

    /// Current `(upper, middle, lower)` as an [`IndicatorValue::Channel3`].
    pub fn value(&self) -> IndicatorValue {
        IndicatorValue::Channel3 {
            upper: self.upper,
            middle: self.middle,
            lower: self.lower,
        }
    }

    /// Current `(upper, middle, lower)` as a tuple.
    pub fn value_tuple(&self) -> (f64, f64, f64) {
        (self.upper, self.middle, self.lower)
    }

    /// Current regression (midline) value.
    pub fn regression_line(&self) -> f64 {
        self.middle
    }

    /// Current upper band (0.0 until ready).
    pub fn upper(&self) -> f64 {
        self.upper
    }

    /// Current lower band (0.0 until ready).
    pub fn lower(&self) -> f64 {
        self.lower
    }

    /// Regression slope (0.0 until ready).
    pub fn slope(&self) -> f64 {
        self.slope
    }

    /// Regression intercept (0.0 until ready).
    pub fn intercept(&self) -> f64 {
        self.intercept
    }

    /// Coefficient of determination of the fit (0.0 until ready).
    pub fn r_squared(&self) -> f64 {
        self.r_squared
    }

    /// Population standard deviation of residuals (0.0 until ready).
    pub fn std_dev(&self) -> f64 {
        self.std_dev
    }

    /// Trend classification derived from the slope.
    pub fn trend_direction(&self) -> TrendDirection {
        self.trend_direction
    }

    /// `upper - lower`, or 0.0 when not ready.
    pub fn channel_width(&self) -> f64 {
        if self.is_ready() {
            self.upper - self.lower
        } else {
            0.0
        }
    }

    /// Relative position of `price` between the bands.
    ///
    /// Returns a value clamped to `[0, 1]` (0 = lower band, 1 = upper band).
    /// Returns 0.5 when not ready or when the bands coincide.
    pub fn position_in_channel(&self, price: f64) -> f64 {
        if !self.is_ready() || self.upper == self.lower {
            0.5
        } else {
            ((price - self.lower) / (self.upper - self.lower)).clamp(0.0, 1.0)
        }
    }

    /// True when ready and R² is at least `r2_threshold`.
    pub fn is_good_fit(&self, r2_threshold: f64) -> bool {
        self.is_ready() && self.r_squared >= r2_threshold
    }

    /// True when the fit is good and `|slope|` is at least `slope_threshold`.
    pub fn is_strong_trend(&self, r2_threshold: f64, slope_threshold: f64) -> bool {
        self.is_good_fit(r2_threshold) && self.slope.abs() >= slope_threshold
    }

    /// Extrapolates the regression line `bars_ahead` bars past the newest bar.
    ///
    /// Evaluates the line at `x = period + bars_ahead`, the same x convention
    /// used for residuals. Returns `None` until ready.
    ///
    /// NOTE(review): this assumes `LinearRegressionMA::intercept` is expressed
    /// for x starting at 1 — confirm against that type.
    pub fn predict_price(&self, bars_ahead: usize) -> Option<f64> {
        if !self.is_ready() {
            return None;
        }
        let x = (self.period + bars_ahead) as f64;
        Some(self.slope * x + self.intercept)
    }

    /// True once the regression has warmed up and `period` prices are buffered.
    pub fn is_ready(&self) -> bool {
        self.lr.is_ready() && self.buffer_filled
    }

    /// Returns the indicator to its freshly constructed state.
    pub fn reset(&mut self) {
        self.lr.reset();
        self.price_buffer.clear();
        self.residuals_buffer.clear();
        self.buffer_index = 0;
        self.buffer_filled = false;
        // `reset_channels` deliberately keeps `middle` for warm-up; a full reset must clear it.
        self.middle = 0.0;
        self.reset_channels();
    }

    /// Regression window length in bars.
    pub fn period(&self) -> usize {
        self.period
    }

    /// Band multiplier (interpretation depends on [`mode`](Self::mode)).
    pub fn std_dev_mult(&self) -> f64 {
        self.std_dev_mult
    }

    /// Band sizing strategy.
    pub fn mode(&self) -> RegressionChannelMode {
        self.mode
    }
}

#[cfg(test)]
mod chronological_residual_tests {
    use super::*;

    /// After the price ring buffer wraps, residuals must still be paired with
    /// their chronological x positions. A perfectly linear series then has
    /// zero spread around its own regression line.
    #[test]
    fn linear_series_has_zero_spread_after_buffer_wraps() {
        let mut rc = RegressionChannels::new_standard(20, 2.0);
        for i in 0..37 {
            let price = 100.0 + i as f64;
            rc.update_bar(price, price + 1.0, price - 1.0, price, 1000.0);
        }
        assert!(rc.is_ready());
        assert!(rc.std_dev() < 1e-6, "std_dev = {}", rc.std_dev());
    }

    #[test]
    fn reset_clears_regression_line() {
        let mut rc = RegressionChannels::new_standard(20, 2.0);
        for i in 0..25 {
            let price = 100.0 + i as f64;
            rc.update_bar(price, price + 1.0, price - 1.0, price, 1000.0);
        }
        rc.reset();
        assert_eq!(rc.regression_line(), 0.0);
        assert_eq!(rc.value_tuple(), (0.0, 0.0, 0.0));
    }
}
impl Default for RegressionChannels {
fn default() -> Self {
Self::new_standard(20, 2.0)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `count` synthetic bars. Each bar's open and close equal
    /// `price_at(i)`, its high and low sit one unit either side, and volume is fixed.
    fn feed(rc: &mut RegressionChannels, count: usize, price_at: impl Fn(usize) -> f64) {
        for i in 0..count {
            let p = price_at(i);
            rc.update_bar(p, p + 1.0, p - 1.0, p, 1000.0);
        }
    }

    #[test]
    fn test_regression_channels_creation() {
        let channels = RegressionChannels::new_standard(20, 2.0);
        assert!(!channels.is_ready());
        assert_eq!(channels.upper(), 0.0);
        assert_eq!(channels.lower(), 0.0);
    }

    #[test]
    fn test_regression_channels_warmup() {
        let mut channels = RegressionChannels::new_standard(20, 2.0);
        feed(&mut channels, 25, |i| 100.0 + (i as f64 * 0.1).sin() * 5.0);
        assert!(channels.is_ready());
    }

    #[test]
    fn test_regression_channels_values() {
        let mut channels = RegressionChannels::new_standard(20, 2.0);
        feed(&mut channels, 25, |i| 100.0 + i as f64);
        let mid = channels.regression_line();
        assert!(channels.upper() >= mid);
        assert!(mid >= channels.lower());
    }

    #[test]
    fn test_regression_channels_trend() {
        let mut channels = RegressionChannels::new_standard(20, 2.0);
        feed(&mut channels, 25, |i| 100.0 + i as f64 * 2.0);
        assert!(channels.slope() > 0.0);
        assert_eq!(channels.trend_direction(), TrendDirection::Uptrend);
    }

    #[test]
    fn test_regression_channels_r2() {
        let mut channels = RegressionChannels::new_standard(20, 2.0);
        feed(&mut channels, 25, |i| 100.0 + i as f64);
        assert!((0.0..=1.0).contains(&channels.r_squared()));
    }

    #[test]
    fn test_regression_channels_reset() {
        let mut channels = RegressionChannels::new_standard(20, 2.0);
        for step in 0..25 {
            let ramp = 100.0 + step as f64;
            channels.update_bar(ramp, 101.0, 99.0, ramp, 1000.0);
        }
        channels.reset();
        assert!(!channels.is_ready());
        assert_eq!(channels.upper(), 0.0);
        assert_eq!(channels.lower(), 0.0);
    }
}