// quantwave_core/indicators/emd.rs

use std::f64::consts::PI;

use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
use crate::indicators::smoothing::SMA;
use crate::traits::Next;

6/// Empirical Mode Decomposition (EMD)
7///
8/// Based on John Ehlers' "Empirical Mode Decomposition" (2010).
9/// EMD decomposes price data into a cycle component (via Bandpass filter)
10/// and a trend component (via averaging the bandpass output).
11/// It also provides thresholds based on averaged peaks/valleys to identify market modes.
12///
13/// Returns (Trend, UpperThreshold, LowerThreshold).
14#[derive(Debug, Clone)]
15pub struct EMD {
16    alpha: f64,
17    beta: f64,
18    fraction: f64,
19    price_prev1: f64,
20    price_prev2: f64,
21    bp_history: [f64; 2],
22    mean_sma: SMA,
23    peak_sma: SMA,
24    valley_sma: SMA,
25    peak: f64,
26    valley: f64,
27    count: usize,
28}
29
30impl EMD {
31    pub fn new(period: usize, delta: f64, fraction: f64) -> Self {
32        // beta = Cosine(360 / Period);
33        let beta = (2.0 * PI / period as f64).cos();
34        // gamma = 1 / Cosine(720*delta / Period);
35        let gamma = 1.0 / (4.0 * PI * delta / period as f64).cos();
36        // alpha = gamma - SquareRoot(gamma*gamma - 1);
37        let alpha = gamma - (gamma * gamma - 1.0).sqrt();
38
39        Self {
40            alpha,
41            beta,
42            fraction,
43            price_prev1: 0.0,
44            price_prev2: 0.0,
45            bp_history: [0.0; 2],
46            mean_sma: SMA::new(2 * period),
47            peak_sma: SMA::new(50),
48            valley_sma: SMA::new(50),
49            peak: 0.0,
50            valley: 0.0,
51            count: 0,
52        }
53    }
54}
55
56impl Next<f64> for EMD {
57    type Output = (f64, f64, f64); // (Trend/Mean, Upper, Lower)
58
59    fn next(&mut self, input: f64) -> Self::Output {
60        self.count += 1;
61
62        // BP = .5*(1 - alpha)*(Price - Price[2]) + beta*(1 + alpha)*BP[1] - alpha*BP[2];
63        let bp = 0.5 * (1.0 - self.alpha) * (input - self.price_prev2)
64            + self.beta * (1.0 + self.alpha) * self.bp_history[0]
65            - self.alpha * self.bp_history[1];
66
67        // Mean = Average(BP, 2*Period);
68        let mean = self.mean_sma.next(bp);
69
70        // Peak/Valley logic
71        // If BP[1] > BP and BP[1] > BP[2] Then Peak = BP[1];
72        if self.count > 2 {
73            if self.bp_history[0] > bp && self.bp_history[0] > self.bp_history[1] {
74                self.peak = self.bp_history[0];
75            }
76            if self.bp_history[0] < bp && self.bp_history[0] < self.bp_history[1] {
77                self.valley = self.bp_history[0];
78            }
79        }
80
81        // AvgPeak = Average(Peak, 50);
82        let avg_peak = self.peak_sma.next(self.peak);
83        let avg_valley = self.valley_sma.next(self.valley);
84
85        // Shift history
86        self.bp_history[1] = self.bp_history[0];
87        self.bp_history[0] = bp;
88
89        self.price_prev2 = self.price_prev1;
90        self.price_prev1 = input;
91
92        (mean, self.fraction * avg_peak, self.fraction * avg_valley)
93    }
94}
95
96pub const EMD_METADATA: IndicatorMetadata = IndicatorMetadata {
97    name: "EMD",
98    description: "Empirical Mode Decomposition separates cycles from trends using bandpass filtering and identifies market modes via adaptive thresholds.",
99    params: &[
100        ParamDef {
101            name: "period",
102            default: "20",
103            description: "Bandpass center period",
104        },
105        ParamDef {
106            name: "delta",
107            default: "0.5",
108            description: "Bandwidth half-width",
109        },
110        ParamDef {
111            name: "fraction",
112            default: "0.1",
113            description: "Threshold multiplier for peaks/valleys",
114        },
115    ],
116    formula_source: "https://github.com/lavs9/quantwave/blob/main/references/Ehlers%20Papers/EmpiricalModeDecomposition.pdf",
117    formula_latex: r#"
118\[
119\beta = \cos\left(\frac{360}{P}\right), \gamma = \frac{1}{\cos\left(\frac{720\delta}{P}\right)}, \alpha = \gamma - \sqrt{\gamma^2 - 1}
120\]
121\[
122BP = 0.5(1 - \alpha)(Price - Price_{t-2}) + \beta(1 + \alpha)BP_{t-1} - \alpha BP_{t-2}
123\]
124\[
125Mean = \text{SMA}(BP, 2P)
126\]
127\[
128Threshold = \text{Fraction} \cdot \text{SMA}(\text{Peak/Valley}, 50)
129\]
130"#,
131    gold_standard_file: "emd.json",
132    category: "Ehlers DSP",
133};
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Next;
    use proptest::prelude::*;

    /// Recomputes `(alpha, beta)` with the same formulas as `EMD::new`, so the
    /// reference computations below do not depend on the struct's fields.
    fn coefficients(period: usize, delta: f64) -> (f64, f64) {
        let beta = (2.0 * PI / period as f64).cos();
        let gamma = 1.0 / (4.0 * PI * delta / period as f64).cos();
        let alpha = gamma - (gamma * gamma - 1.0).sqrt();
        (alpha, beta)
    }

    #[test]
    fn test_emd_basic() {
        let mut emd = EMD::new(20, 0.5, 0.1);
        for input in [10.0, 11.0, 12.0, 13.0, 14.0, 15.0] {
            let (m, u, l) = emd.next(input);
            assert!(m.is_finite());
            assert!(u.is_finite());
            assert!(l.is_finite());
        }
    }

    #[test]
    fn test_emd_first_bars_match_recurrence() {
        let (alpha, beta) = coefficients(20, 0.5);
        let mut emd = EMD::new(20, 0.5, 0.1);

        // Bar 1: Price[2] and the BP history are still at their zero seed,
        // and no peak/valley can have been detected yet.
        let bp1 = 0.5 * (1.0 - alpha) * 10.0;
        let (m, u, l) = emd.next(10.0);
        approx::assert_relative_eq!(m, bp1, epsilon = 1e-12);
        approx::assert_relative_eq!(u, 0.0, epsilon = 1e-12);
        approx::assert_relative_eq!(l, 0.0, epsilon = 1e-12);

        // Bar 2: Price[2] is still zero; BP[1] is bar 1's output.
        let bp2 = 0.5 * (1.0 - alpha) * 11.0 + beta * (1.0 + alpha) * bp1;
        let (m, u, l) = emd.next(11.0);
        approx::assert_relative_eq!(m, (bp1 + bp2) / 2.0, epsilon = 1e-12);
        approx::assert_relative_eq!(u, 0.0, epsilon = 1e-12);
        approx::assert_relative_eq!(l, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_emd_sine_at_center_period_has_unit_gain() {
        let period = 20;
        let mut emd = EMD::new(period, 0.5, 0.1);
        let (m, u, l) = (0..200)
            .map(|i| 50.0 + (2.0 * PI * i as f64 / period as f64).sin())
            .map(|price| emd.next(price))
            .last()
            .expect("input is non-empty");

        // At its centre period the bandpass has unit gain and zero phase, so in
        // steady state BP tracks the unit-amplitude sine. Peaks are ~+1 and
        // valleys ~-1. A 2*period average spans whole cycles and cancels to ~0.
        assert!(m.abs() < 1e-6, "mean should cancel over whole cycles, got {}", m);
        assert!((u - 0.1).abs() < 1e-3, "upper threshold should be ~0.1, got {}", u);
        assert!((l + 0.1).abs() < 1e-3, "lower threshold should be ~-0.1, got {}", l);
    }

    proptest! {
        #[test]
        fn test_emd_parity(
            inputs in prop::collection::vec(1.0..100.0, 150..250),
        ) {
            let period = 20;
            let delta = 0.5;
            let fraction = 0.1;
            let mut emd = EMD::new(period, delta, fraction);
            let streaming_results: Vec<(f64, f64, f64)> =
                inputs.iter().map(|&x| emd.next(x)).collect();

            // Batch reference: direct transcription of the EasyLanguage
            // recurrence, with two zero-padded slots in front of each history.
            let mut batch_results = Vec::with_capacity(inputs.len());
            let (alpha, beta) = coefficients(period, delta);

            let mut price_hist = vec![0.0; inputs.len() + 4];
            let mut bp_hist = vec![0.0; inputs.len() + 4];
            let mut peak = 0.0;
            let mut valley = 0.0;
            let mut peak_hist = Vec::new();
            let mut valley_hist = Vec::new();

            for (i, &input) in inputs.iter().enumerate() {
                let bar = i + 1;
                let idx = i + 2;
                price_hist[idx] = input;

                let bp = 0.5 * (1.0 - alpha) * (price_hist[idx] - price_hist[idx - 2])
                    + beta * (1.0 + alpha) * bp_hist[idx - 1]
                    - alpha * bp_hist[idx - 2];
                bp_hist[idx] = bp;

                let mut mean_sum = 0.0;
                let mean_len = (2 * period).min(bar);
                for j in 0..mean_len {
                    mean_sum += bp_hist[idx - j];
                }
                let mean = mean_sum / mean_len as f64;

                if bar > 2 {
                    if bp_hist[idx - 1] > bp && bp_hist[idx - 1] > bp_hist[idx - 2] {
                        peak = bp_hist[idx - 1];
                    }
                    if bp_hist[idx - 1] < bp && bp_hist[idx - 1] < bp_hist[idx - 2] {
                        valley = bp_hist[idx - 1];
                    }
                }
                peak_hist.push(peak);
                valley_hist.push(valley);

                let p_len = bar.min(50);
                let mut p_sum = 0.0;
                for j in 0..p_len {
                    p_sum += peak_hist[i - j];
                }
                let avg_p = p_sum / p_len as f64;

                let mut v_sum = 0.0;
                for j in 0..p_len {
                    v_sum += valley_hist[i - j];
                }
                let avg_v = v_sum / p_len as f64;

                batch_results.push((mean, fraction * avg_p, fraction * avg_v));
            }

            for (s, b) in streaming_results.iter().zip(batch_results.iter()) {
                approx::assert_relative_eq!(s.0, b.0, epsilon = 1e-10);
                approx::assert_relative_eq!(s.1, b.1, epsilon = 1e-10);
                approx::assert_relative_eq!(s.2, b.2, epsilon = 1e-10);
            }
        }
    }
}