// quantwave_core/indicators/gaussian.rs

use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
use crate::traits::Next;
use std::f64::consts::PI;

5/// Gaussian Filter
6///
7/// Based on John Ehlers' "Gaussian and Other Low Lag Filters".
8/// A family of low-pass filters with N poles at the same location.
9/// Provides approximately half the lag of an equivalent Butterworth filter.
10#[derive(Debug, Clone)]
11pub struct GaussianFilter {
12    poles: usize,
13    alpha: f64,
14    // a^n
15    alpha_pow: f64,
16    // (1-a)^n
17    one_minus_alpha: f64,
18    price_history: Vec<f64>,
19    filt_history: Vec<f64>,
20    count: usize,
21}
22
23impl GaussianFilter {
24    pub fn new(period: usize, poles: usize) -> Self {
25        let poles = poles.clamp(1, 4);
26        let p = period as f64;
27        let omega = 2.0 * PI / p;
28        // beta = (1 - cos(omega)) / (2^(1/(2N)) - 1)
29        // 1.4142 is sqrt(2), so 2^(1/(2N))
30        let beta = (1.0 - omega.cos()) / (2.0_f64.powf(1.0 / (2.0 * poles as f64)) - 1.0);
31        let alpha = -beta + (beta * beta + 2.0 * beta).sqrt();
32        
33        Self {
34            poles,
35            alpha,
36            alpha_pow: alpha.powi(poles as i32),
37            one_minus_alpha: 1.0 - alpha,
38            price_history: vec![0.0; poles + 1],
39            filt_history: vec![0.0; poles + 1],
40            count: 0,
41        }
42    }
43}
44
45impl Default for GaussianFilter {
46    fn default() -> Self {
47        Self::new(14, 4)
48    }
49}
50
51impl Next<f64> for GaussianFilter {
52    type Output = f64;
53
54    fn next(&mut self, input: f64) -> Self::Output {
55        self.count += 1;
56        
57        let res = match self.poles {
58            1 => {
59                // f = a*g + (1-a)f[1]
60                if self.count < 2 {
61                    input
62                } else {
63                    self.alpha * input + self.one_minus_alpha * self.filt_history[0]
64                }
65            }
66            2 => {
67                // f = a^2*g + 2(1-a)f[1] - (1-a)^2f[2]
68                if self.count < 3 {
69                    input
70                } else {
71                    self.alpha_pow * input
72                        + 2.0 * self.one_minus_alpha * self.filt_history[0]
73                        - self.one_minus_alpha.powi(2) * self.filt_history[1]
74                }
75            }
76            3 => {
77                // f = a^3*g + 3(1-a)f[1] - 3(1-a)^2f[2] + (1-a)^3f[3]
78                if self.count < 4 {
79                    input
80                } else {
81                    self.alpha_pow * input
82                        + 3.0 * self.one_minus_alpha * self.filt_history[0]
83                        - 3.0 * self.one_minus_alpha.powi(2) * self.filt_history[1]
84                        + self.one_minus_alpha.powi(3) * self.filt_history[2]
85                }
86            }
87            4 => {
88                // f = a^4*g + 4(1-a)f[1] - 6(1-a)^2f[2] + 4(1-a)^3f[3] - (1-a)^4f[4]
89                if self.count < 5 {
90                    input
91                } else {
92                    self.alpha_pow * input
93                        + 4.0 * self.one_minus_alpha * self.filt_history[0]
94                        - 6.0 * self.one_minus_alpha.powi(2) * self.filt_history[1]
95                        + 4.0 * self.one_minus_alpha.powi(3) * self.filt_history[2]
96                        - self.one_minus_alpha.powi(4) * self.filt_history[3]
97                }
98            }
99            _ => input,
100        };
101
102        // Shift history
103        for i in (1..self.poles).rev() {
104            self.filt_history[i] = self.filt_history[i - 1];
105            self.price_history[i] = self.price_history[i - 1];
106        }
107        self.filt_history[0] = res;
108        self.price_history[0] = input;
109        
110        res
111    }
112}
113
114pub const GAUSSIAN_FILTER_METADATA: IndicatorMetadata = IndicatorMetadata {
115    name: "GaussianFilter",
116    description: "Multi-pole Gaussian low-pass filter for reduced lag.",
117    params: &[
118        ParamDef {
119            name: "period",
120            default: "14",
121            description: "Critical period",
122        },
123        ParamDef {
124            name: "poles",
125            default: "4",
126            description: "Number of poles (1-4)",
127        },
128    ],
129    formula_source: "https://github.com/lavs9/quantwave/blob/main/references/Ehlers%20Papers/GaussianFilters.pdf",
130    formula_latex: r#"
131\[
132\alpha = -\beta + \sqrt{\beta^2 + 2\beta}
133\]
134\[
135\beta = \frac{1 - \cos(2\pi/P)}{2^{1/(2N)} - 1}
136\]
137"#,
138    gold_standard_file: "gaussian_filter.json",
139    category: "Ehlers DSP",
140};
141
142#[cfg(test)]
143mod tests {
144    use super::*;
145    use crate::traits::Next;
146    use proptest::prelude::*;
147
148    #[test]
149    fn test_gaussian_basic() {
150        let mut filter = GaussianFilter::new(14, 4);
151        for i in 0..50 {
152            let val = filter.next(100.0);
153            if i > 20 {
154                approx::assert_relative_eq!(val, 100.0, epsilon = 1.0);
155            }
156        }
157    }
158
159    proptest! {
160        #[test]
161        fn test_gaussian_parity(
162            inputs in prop::collection::vec(1.0..100.0, 10..100),
163            poles in 1usize..4usize,
164        ) {
165            let p = 14;
166            let mut filter = GaussianFilter::new(p, poles);
167            let streaming_results: Vec<f64> = inputs.iter().map(|&x| filter.next(x)).collect();
168
169            // Batch implementation
170            let mut batch_results = Vec::with_capacity(inputs.len());
171            let p_f = p as f64;
172            let omega = 2.0 * PI / p_f;
173            let beta = (1.0 - omega.cos()) / (2.0_f64.powf(1.0 / (2.0 * poles as f64)) - 1.0);
174            let alpha = -beta + (beta * beta + 2.0 * beta).sqrt();
175            let alpha_pow = alpha.powi(poles as i32);
176            let oma = 1.0 - alpha;
177
178            let mut f_hist = vec![0.0; poles];
179            
180            for (i, &input) in inputs.iter().enumerate() {
181                let bar = i + 1;
182                let res = if bar < poles + 1 {
183                    input
184                } else {
185                    match poles {
186                        1 => alpha_pow * input + oma * f_hist[0],
187                        2 => alpha_pow * input + 2.0 * oma * f_hist[0] - oma.powi(2) * f_hist[1],
188                        3 => alpha_pow * input + 3.0 * oma * f_hist[0] - 3.0 * oma.powi(2) * f_hist[1] + oma.powi(3) * f_hist[2],
189                        4 => alpha_pow * input + 4.0 * oma * f_hist[0] - 6.0 * oma.powi(2) * f_hist[1] + 4.0 * oma.powi(3) * f_hist[2] - oma.powi(4) * f_hist[3],
190                        _ => input,
191                    }
192                };
193                
194                for j in (1..poles).rev() {
195                    f_hist[j] = f_hist[j-1];
196                }
197                f_hist[0] = res;
198                batch_results.push(res);
199            }
200
201            for (s, b) in streaming_results.iter().zip(batch_results.iter()) {
202                approx::assert_relative_eq!(s, b, epsilon = 1e-10);
203            }
204        }
205    }
206}