Skip to main content

quantwave_core/indicators/
griffiths_predictor.rs

1use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
2use crate::traits::Next;
3use crate::indicators::high_pass::HighPass;
4use crate::indicators::super_smoother::SuperSmoother;
5use std::collections::VecDeque;
6
7/// Griffiths Predictor
8///
9/// Based on John Ehlers' "Linear Predictive Filters And Instantaneous Frequency" (TASC January 2025).
10/// Uses an adaptive LMS (Griffiths) algorithm to predict future signal values.
11#[derive(Debug, Clone)]
12pub struct GriffithsPredictor {
13    length: usize,
14    bars_fwd: usize,
15    mu: f64,
16    hp: HighPass,
17    ss: SuperSmoother,
18    peak: f64,
19    signal_window: VecDeque<f64>,
20    coef: Vec<f64>,
21}
22
23impl GriffithsPredictor {
24    pub fn new(lower_bound: usize, upper_bound: usize, length: usize, bars_fwd: usize) -> Self {
25        Self {
26            length,
27            bars_fwd,
28            mu: 1.0 / (length as f64),
29            hp: HighPass::new(upper_bound),
30            ss: SuperSmoother::new(lower_bound),
31            peak: 0.1,
32            signal_window: VecDeque::with_capacity(length + 1),
33            coef: vec![0.0; length + 1], // 1-indexed logic compatibility
34        }
35    }
36}
37
38impl Default for GriffithsPredictor {
39    fn default() -> Self {
40        Self::new(18, 40, 18, 2)
41    }
42}
43
44impl Next<f64> for GriffithsPredictor {
45    type Output = f64;
46
47    fn next(&mut self, input: f64) -> Self::Output {
48        let hp_val = self.hp.next(input);
49        let lp_val = self.ss.next(hp_val);
50
51        // Peak detection
52        self.peak *= 0.991;
53        if lp_val.abs() > self.peak {
54            self.peak = lp_val.abs();
55        }
56
57        let signal = if self.peak != 0.0 {
58            lp_val / self.peak
59        } else {
60            0.0
61        };
62
63        self.signal_window.push_front(signal);
64        if self.signal_window.len() > self.length {
65            self.signal_window.pop_back();
66        }
67
68        if self.signal_window.len() < self.length {
69            return 0.0;
70        }
71
72        // Current signal is at index 0 (latest)
73        // Previous signals are at indices 1..Length-1
74        // Ehlers' XX[Length] is current signal.
75        // XX[Length - count] is previous signals.
76        // XX[Length - 1] = window[1]
77        // XX[Length - length] = window[length]? Wait.
78        
79        // Let's use Ehlers' indexing directly by copying to a temp vector
80        let mut xx = vec![0.0; self.length + 1];
81        for i in 1..=self.length {
82            xx[i] = self.signal_window[self.length - i];
83        }
84
85        let mut x_bar = 0.0;
86        for count in 1..=self.length {
87            x_bar += xx[self.length - count] * self.coef[count];
88        }
89
90        for count in 1..=self.length {
91            self.coef[count] += self.mu * (xx[self.length] - x_bar) * xx[self.length - count];
92        }
93
94        // Prediction
95        let mut x_pred = 0.0;
96        let mut xx_temp = xx.clone();
97        for _advance in 1..=self.bars_fwd {
98            x_pred = 0.0;
99            for count in 1..=self.length {
100                x_pred += xx_temp[self.length + 1 - count] * self.coef[count];
101            }
102            
103            // Shift
104            for count in 1..self.length {
105                xx_temp[count] = xx_temp[count + 1];
106            }
107            xx_temp[self.length] = x_pred;
108        }
109
110        x_pred
111    }
112}
113
114pub const GRIFFITHS_PREDICTOR_METADATA: IndicatorMetadata = IndicatorMetadata {
115    name: "GriffithsPredictor",
116    description: "Adaptive LMS linear predictive filter for signal forecasting.",
117    params: &[
118        ParamDef {
119            name: "lower_bound",
120            default: "18",
121            description: "Lower frequency bound (SS length)",
122        },
123        ParamDef {
124            name: "upper_bound",
125            default: "40",
126            description: "Upper frequency bound (HP length)",
127        },
128        ParamDef {
129            name: "length",
130            default: "18",
131            description: "LMS filter length",
132        },
133        ParamDef {
134            name: "bars_fwd",
135            default: "2",
136            description: "Number of bars to predict forward",
137        },
138    ],
139    formula_source: "https://github.com/lavs9/quantwave/blob/main/references/traderstipsreference/TRADERS’%20TIPS%20-%20JANUARY%202025.html",
140    formula_latex: r#"
141\[
142\mu = 1/L
143\]
144\[
145\bar{x} = \sum_{i=1}^L xx_{L-i} \cdot coef_i
146\]
147\[
148coef_i = coef_i + \mu(xx_L - \bar{x})xx_{L-i}
149\]
150"#,
151    gold_standard_file: "griffiths_predictor.json",
152    category: "Ehlers DSP",
153};
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Next;
    use proptest::prelude::*;

    /// Smoke test: a smooth sine-on-offset input never produces NaN.
    #[test]
    fn test_griffiths_predictor_basic() {
        let mut gp = GriffithsPredictor::new(18, 40, 18, 2);
        for i in 0..100 {
            let val = gp.next(100.0 + (i as f64 * 0.1).sin());
            assert!(!val.is_nan());
        }
    }

    /// The first `length - 1` outputs are the warm-up value 0.0.
    #[test]
    fn test_griffiths_predictor_warmup_is_zero() {
        let length = 18;
        let mut gp = GriffithsPredictor::new(18, 40, length, 2);
        for i in 0..length - 1 {
            let val = gp.next(100.0 + (i as f64 * 0.1).sin());
            assert_eq!(val, 0.0, "warm-up output {i} should be 0.0");
        }
    }

    /// Degenerate configurations return 0.0 instead of panicking.
    #[test]
    fn test_griffiths_predictor_degenerate_params() {
        let mut zero_len = GriffithsPredictor::new(18, 40, 0, 2);
        let mut zero_fwd = GriffithsPredictor::new(18, 40, 18, 0);
        for i in 0..50 {
            let x = 100.0 + (i as f64 * 0.1).sin();
            assert_eq!(zero_len.next(x), 0.0);
            assert_eq!(zero_fwd.next(x), 0.0);
        }
    }

    /// `Default` behaves exactly like `new(18, 40, 18, 2)`.
    #[test]
    fn test_griffiths_predictor_default_matches_new() {
        let mut from_default = GriffithsPredictor::default();
        let mut from_new = GriffithsPredictor::new(18, 40, 18, 2);
        for i in 0..60 {
            let x = 100.0 + (i as f64 * 0.1).sin();
            assert_eq!(from_default.next(x).to_bits(), from_new.next(x).to_bits());
        }
    }

    proptest! {
        /// Re-derives the predictor in batch form over a `Vec` of signals and
        /// checks the streaming implementation produces the same outputs.
        #[test]
        fn test_griffiths_predictor_parity(
            inputs in prop::collection::vec(1.0..100.0, 100..200),
        ) {
            let lb = 18;
            let ub = 40;
            let length = 18;
            let bars_fwd = 2;
            let mut gp = GriffithsPredictor::new(lb, ub, length, bars_fwd);
            let streaming_results: Vec<f64> = inputs.iter().map(|&x| gp.next(x)).collect();

            // Batch implementation
            let mut batch_results = Vec::with_capacity(inputs.len());
            let mut hp = HighPass::new(ub);
            let mut ss = SuperSmoother::new(lb);
            let lp_vals: Vec<f64> = inputs.iter().map(|&x| ss.next(hp.next(x))).collect();

            let mut peak = 0.1;
            let mut signals = Vec::new();
            let mut coef = vec![0.0; length + 1];
            let mu = 1.0 / length as f64;

            for (i, &lp_val) in lp_vals.iter().enumerate() {
                peak *= 0.991;
                if lp_val.abs() > peak {
                    peak = lp_val.abs();
                }
                let signal = if peak != 0.0 { lp_val / peak } else { 0.0 };
                signals.push(signal);

                if signals.len() < length {
                    batch_results.push(0.0);
                    continue;
                }

                let mut xx = vec![0.0; length + 1];
                for j in 1..=length {
                    xx[j] = signals[i - (length - j)];
                }

                let mut x_bar = 0.0;
                for count in 1..=length {
                    x_bar += xx[length - count] * coef[count];
                }

                for count in 1..=length {
                    coef[count] += mu * (xx[length] - x_bar) * xx[length - count];
                }

                let mut x_pred = 0.0;
                let mut xx_temp = xx.clone();
                for _advance in 1..=bars_fwd {
                    x_pred = 0.0;
                    for count in 1..=length {
                        x_pred += xx_temp[length + 1 - count] * coef[count];
                    }
                    for count in 1..length {
                        xx_temp[count] = xx_temp[count + 1];
                    }
                    xx_temp[length] = x_pred;
                }
                batch_results.push(x_pred);
            }

            for (s, b) in streaming_results.iter().zip(batch_results.iter()) {
                approx::assert_relative_eq!(s, b, epsilon = 1e-10);
            }
        }
    }
}