Skip to main content

quantwave_core/indicators/
gaussian.rs

1use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
2use crate::traits::Next;
3use std::f64::consts::PI;
4
5/// Gaussian Filter
6///
7/// Based on John Ehlers' "Gaussian and Other Low Lag Filters".
8/// A family of low-pass filters with N poles at the same location.
9/// Provides approximately half the lag of an equivalent Butterworth filter.
10#[derive(Debug, Clone)]
11pub struct GaussianFilter {
12    poles: usize,
13    alpha: f64,
14    // a^n
15    alpha_pow: f64,
16    // (1-a)^n
17    one_minus_alpha: f64,
18    price_history: Vec<f64>,
19    filt_history: Vec<f64>,
20    count: usize,
21}
22
23impl GaussianFilter {
24    pub fn new(period: usize, poles: usize) -> Self {
25        let poles = poles.clamp(1, 4);
26        let p = period as f64;
27        let omega = 2.0 * PI / p;
28        // beta = (1 - cos(omega)) / (2^(1/(2N)) - 1)
29        // 1.4142 is sqrt(2), so 2^(1/(2N))
30        let beta = (1.0 - omega.cos()) / (2.0_f64.powf(1.0 / (2.0 * poles as f64)) - 1.0);
31        let alpha = -beta + (beta * beta + 2.0 * beta).sqrt();
32        
33        Self {
34            poles,
35            alpha,
36            alpha_pow: alpha.powi(poles as i32),
37            one_minus_alpha: 1.0 - alpha,
38            price_history: vec![0.0; poles + 1],
39            filt_history: vec![0.0; poles + 1],
40            count: 0,
41        }
42    }
43}
44
45impl Default for GaussianFilter {
46    fn default() -> Self {
47        Self::new(14, 4)
48    }
49}
50
51impl Next<f64> for GaussianFilter {
52    type Output = f64;
53
54    fn next(&mut self, input: f64) -> Self::Output {
55        self.count += 1;
56        
57        let res = match self.poles {
58            1 => {
59                // f = a*g + (1-a)f[1]
60                if self.count < 2 {
61                    input
62                } else {
63                    self.alpha * input + self.one_minus_alpha * self.filt_history[0]
64                }
65            }
66            2 => {
67                // f = a^2*g + 2(1-a)f[1] - (1-a)^2f[2]
68                if self.count < 3 {
69                    input
70                } else {
71                    self.alpha_pow * input
72                        + 2.0 * self.one_minus_alpha * self.filt_history[0]
73                        - self.one_minus_alpha.powi(2) * self.filt_history[1]
74                }
75            }
76            3 => {
77                // f = a^3*g + 3(1-a)f[1] - 3(1-a)^2f[2] + (1-a)^3f[3]
78                if self.count < 4 {
79                    input
80                } else {
81                    self.alpha_pow * input
82                        + 3.0 * self.one_minus_alpha * self.filt_history[0]
83                        - 3.0 * self.one_minus_alpha.powi(2) * self.filt_history[1]
84                        + self.one_minus_alpha.powi(3) * self.filt_history[2]
85                }
86            }
87            4 => {
88                // f = a^4*g + 4(1-a)f[1] - 6(1-a)^2f[2] + 4(1-a)^3f[3] - (1-a)^4f[4]
89                if self.count < 5 {
90                    input
91                } else {
92                    self.alpha_pow * input
93                        + 4.0 * self.one_minus_alpha * self.filt_history[0]
94                        - 6.0 * self.one_minus_alpha.powi(2) * self.filt_history[1]
95                        + 4.0 * self.one_minus_alpha.powi(3) * self.filt_history[2]
96                        - self.one_minus_alpha.powi(4) * self.filt_history[3]
97                }
98            }
99            _ => input,
100        };
101
102        // Shift history
103        for i in (1..self.poles).rev() {
104            self.filt_history[i] = self.filt_history[i - 1];
105            self.price_history[i] = self.price_history[i - 1];
106        }
107        self.filt_history[0] = res;
108        self.price_history[0] = input;
109        
110        res
111    }
112}
113
114pub const GAUSSIAN_FILTER_METADATA: IndicatorMetadata = IndicatorMetadata {
115    name: "GaussianFilter",
116    description: "Multi-pole Gaussian low-pass filter for reduced lag.",
117    usage: "Use when smooth symmetric price averaging with near-zero phase shift is needed. Works well as a preprocessing step for spectral analysis indicators.",
118    keywords: &["filter", "smoothing", "ehlers", "dsp", "low-pass"],
119    ehlers_summary: "Gaussian filters are the theoretically optimal lowpass filter for minimizing the product of time-domain duration and frequency-domain bandwidth. Ehlers implements them as cascaded pole filters with Gaussian-function-derived coefficients, achieving very smooth output with excellent stopband attenuation.",
120    params: &[
121        ParamDef {
122            name: "period",
123            default: "14",
124            description: "Critical period",
125        },
126        ParamDef {
127            name: "poles",
128            default: "4",
129            description: "Number of poles (1-4)",
130        },
131    ],
132    formula_source: "https://github.com/lavs9/quantwave/blob/main/references/Ehlers%20Papers/GaussianFilters.pdf",
133    formula_latex: r#"
134\[
135\alpha = -\beta + \sqrt{\beta^2 + 2\beta}
136\]
137\[
138\beta = \frac{1 - \cos(2\pi/P)}{2^{1/(2N)} - 1}
139\]
140"#,
141    gold_standard_file: "gaussian_filter.json",
142    category: "Ehlers DSP",
143};
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Next;
    use proptest::prelude::*;

    #[test]
    fn test_gaussian_basic() {
        // A constant series must come out unchanged for every pole count: the recursion has
        // unit DC gain and the warm-up seeds it with the input itself, so there is no
        // start-up transient to wait out.
        for poles in 1..=4 {
            let mut filter = GaussianFilter::new(14, poles);
            for _ in 0..50 {
                let val = filter.next(100.0);
                approx::assert_relative_eq!(val, 100.0, epsilon = 1e-9);
            }
        }
    }

    proptest! {
        #[test]
        fn test_gaussian_parity(
            inputs in prop::collection::vec(1.0..100.0, 10..100),
            poles in 1usize..=4usize,
        ) {
            let mut filter = GaussianFilter::new(14, poles);
            // Take alpha from the filter itself so this test checks only the streaming
            // recursion against an independent batch evaluation, not the alpha formula.
            let alpha = filter.alpha;
            let streaming_results: Vec<f64> = inputs.iter().map(|&x| filter.next(x)).collect();

            // Batch implementation of the explicitly expanded N-pole recursion.
            let alpha_pow = alpha.powi(poles as i32);
            let oma = 1.0 - alpha;
            let mut f_hist = vec![0.0; poles];
            let mut batch_results = Vec::with_capacity(inputs.len());

            for (i, &input) in inputs.iter().enumerate() {
                // The first `poles` bars seed the recursion with the raw input.
                let res = if i < poles {
                    input
                } else {
                    match poles {
                        1 => alpha_pow * input + oma * f_hist[0],
                        2 => alpha_pow * input + 2.0 * oma * f_hist[0] - oma.powi(2) * f_hist[1],
                        3 => alpha_pow * input + 3.0 * oma * f_hist[0] - 3.0 * oma.powi(2) * f_hist[1] + oma.powi(3) * f_hist[2],
                        4 => alpha_pow * input + 4.0 * oma * f_hist[0] - 6.0 * oma.powi(2) * f_hist[1] + 4.0 * oma.powi(3) * f_hist[2] - oma.powi(4) * f_hist[3],
                        _ => unreachable!("poles is drawn from 1..=4"),
                    }
                };

                // Most recent output first.
                f_hist.rotate_right(1);
                f_hist[0] = res;
                batch_results.push(res);
            }

            prop_assert_eq!(streaming_results.len(), batch_results.len());
            for (&s, &b) in streaming_results.iter().zip(&batch_results) {
                approx::assert_relative_eq!(s, b, epsilon = 1e-10);
            }
        }
    }
}