quantwave_core/indicators/
gaussian.rs1use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
2use crate::traits::Next;
3use std::f64::consts::PI;
4
5#[derive(Debug, Clone)]
11pub struct GaussianFilter {
12 poles: usize,
13 alpha: f64,
14 alpha_pow: f64,
16 one_minus_alpha: f64,
18 price_history: Vec<f64>,
19 filt_history: Vec<f64>,
20 count: usize,
21}
22
23impl GaussianFilter {
24 pub fn new(period: usize, poles: usize) -> Self {
25 let poles = poles.clamp(1, 4);
26 let p = period as f64;
27 let omega = 2.0 * PI / p;
28 let beta = (1.0 - omega.cos()) / (2.0_f64.powf(1.0 / (2.0 * poles as f64)) - 1.0);
31 let alpha = -beta + (beta * beta + 2.0 * beta).sqrt();
32
33 Self {
34 poles,
35 alpha,
36 alpha_pow: alpha.powi(poles as i32),
37 one_minus_alpha: 1.0 - alpha,
38 price_history: vec![0.0; poles + 1],
39 filt_history: vec![0.0; poles + 1],
40 count: 0,
41 }
42 }
43}
44
45impl Default for GaussianFilter {
46 fn default() -> Self {
47 Self::new(14, 4)
48 }
49}
50
51impl Next<f64> for GaussianFilter {
52 type Output = f64;
53
54 fn next(&mut self, input: f64) -> Self::Output {
55 self.count += 1;
56
57 let res = match self.poles {
58 1 => {
59 if self.count < 2 {
61 input
62 } else {
63 self.alpha * input + self.one_minus_alpha * self.filt_history[0]
64 }
65 }
66 2 => {
67 if self.count < 3 {
69 input
70 } else {
71 self.alpha_pow * input
72 + 2.0 * self.one_minus_alpha * self.filt_history[0]
73 - self.one_minus_alpha.powi(2) * self.filt_history[1]
74 }
75 }
76 3 => {
77 if self.count < 4 {
79 input
80 } else {
81 self.alpha_pow * input
82 + 3.0 * self.one_minus_alpha * self.filt_history[0]
83 - 3.0 * self.one_minus_alpha.powi(2) * self.filt_history[1]
84 + self.one_minus_alpha.powi(3) * self.filt_history[2]
85 }
86 }
87 4 => {
88 if self.count < 5 {
90 input
91 } else {
92 self.alpha_pow * input
93 + 4.0 * self.one_minus_alpha * self.filt_history[0]
94 - 6.0 * self.one_minus_alpha.powi(2) * self.filt_history[1]
95 + 4.0 * self.one_minus_alpha.powi(3) * self.filt_history[2]
96 - self.one_minus_alpha.powi(4) * self.filt_history[3]
97 }
98 }
99 _ => input,
100 };
101
102 for i in (1..self.poles).rev() {
104 self.filt_history[i] = self.filt_history[i - 1];
105 self.price_history[i] = self.price_history[i - 1];
106 }
107 self.filt_history[0] = res;
108 self.price_history[0] = input;
109
110 res
111 }
112}
113
114pub const GAUSSIAN_FILTER_METADATA: IndicatorMetadata = IndicatorMetadata {
115 name: "GaussianFilter",
116 description: "Multi-pole Gaussian low-pass filter for reduced lag.",
117 params: &[
118 ParamDef {
119 name: "period",
120 default: "14",
121 description: "Critical period",
122 },
123 ParamDef {
124 name: "poles",
125 default: "4",
126 description: "Number of poles (1-4)",
127 },
128 ],
129 formula_source: "https://github.com/lavs9/quantwave/blob/main/references/Ehlers%20Papers/GaussianFilters.pdf",
130 formula_latex: r#"
131\[
132\alpha = -\beta + \sqrt{\beta^2 + 2\beta}
133\]
134\[
135\beta = \frac{1 - \cos(2\pi/P)}{2^{1/(2N)} - 1}
136\]
137"#,
138 gold_standard_file: "gaussian_filter.json",
139 category: "Ehlers DSP",
140};
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Next;
    use proptest::prelude::*;

    /// A constant input is a fixed point of the recursion (the feedback
    /// coefficients sum to `1 - α^N`), so every output, warm-up included, must
    /// reproduce it up to floating-point rounding.
    #[test]
    fn test_gaussian_basic() {
        let mut filter = GaussianFilter::new(14, 4);
        for _ in 0..50 {
            let val = filter.next(100.0);
            approx::assert_relative_eq!(val, 100.0, epsilon = 1e-9);
        }
    }

    /// The first `poles` outputs are the raw inputs (warm-up passthrough).
    #[test]
    fn test_gaussian_warmup_passthrough() {
        for poles in 1..=4 {
            let mut filter = GaussianFilter::new(14, poles);
            for k in 0..poles {
                let x = 10.0 + k as f64;
                assert_eq!(filter.next(x), x);
            }
        }
    }

    /// Out-of-range pole counts are clamped to `1..=4`.
    #[test]
    fn test_gaussian_poles_clamped() {
        let inputs: Vec<f64> = (0..30)
            .map(|i| 50.0 + 5.0 * (i as f64 * 0.7).sin())
            .collect();
        let run = |poles: usize| -> Vec<f64> {
            let mut filter = GaussianFilter::new(14, poles);
            inputs.iter().map(|&x| filter.next(x)).collect()
        };
        assert_eq!(run(0), run(1));
        assert_eq!(run(9), run(4));
    }

    proptest! {
        #[test]
        fn test_gaussian_parity(
            inputs in prop::collection::vec(1.0..100.0, 10..100),
            // Inclusive upper bound so the default 4-pole configuration is covered.
            poles in 1usize..=4usize,
        ) {
            let p = 14;
            let mut filter = GaussianFilter::new(p, poles);
            let streaming_results: Vec<f64> = inputs.iter().map(|&x| filter.next(x)).collect();

            // Independent batch implementation with the per-pole recursions
            // written out by hand.
            let mut batch_results = Vec::with_capacity(inputs.len());
            let p_f = p as f64;
            let omega = 2.0 * PI / p_f;
            let beta = (1.0 - omega.cos()) / (2.0_f64.powf(1.0 / (2.0 * poles as f64)) - 1.0);
            let alpha = -beta + (beta * beta + 2.0 * beta).sqrt();
            let alpha_pow = alpha.powi(poles as i32);
            let oma = 1.0 - alpha;

            let mut f_hist = vec![0.0; poles];

            for (i, &input) in inputs.iter().enumerate() {
                let bar = i + 1;
                let res = if bar < poles + 1 {
                    input
                } else {
                    match poles {
                        1 => alpha_pow * input + oma * f_hist[0],
                        2 => alpha_pow * input + 2.0 * oma * f_hist[0] - oma.powi(2) * f_hist[1],
                        3 => alpha_pow * input + 3.0 * oma * f_hist[0] - 3.0 * oma.powi(2) * f_hist[1] + oma.powi(3) * f_hist[2],
                        4 => alpha_pow * input + 4.0 * oma * f_hist[0] - 6.0 * oma.powi(2) * f_hist[1] + 4.0 * oma.powi(3) * f_hist[2] - oma.powi(4) * f_hist[3],
                        _ => input,
                    }
                };

                for j in (1..poles).rev() {
                    f_hist[j] = f_hist[j - 1];
                }
                f_hist[0] = res;
                batch_results.push(res);
            }

            for (s, b) in streaming_results.iter().zip(batch_results.iter()) {
                approx::assert_relative_eq!(s, b, epsilon = 1e-10);
            }
        }
    }
}