Skip to main content

quantwave_core/indicators/
simple_predictor.rs

1use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
2use crate::traits::Next;
3use crate::indicators::high_pass::HighPass;
4use crate::indicators::super_smoother::SuperSmoother;
5
6/// Simple 2-Pole Predictor
7///
8/// Based on John Ehlers' "Linear Predictive Filters And Instantaneous Frequency" (TASC January 2025).
9/// A non-adaptive 2-pole linear predictive filter using a fixed Q factor.
10#[derive(Debug, Clone)]
11pub struct SimplePredictor {
12    hp: HighPass,
13    ss: SuperSmoother,
14    q: f64,
15    signal_history: [f64; 2],
16    count: usize,
17}
18
19impl SimplePredictor {
20    pub fn new(hp_len: usize, lp_len: usize, q: f64) -> Self {
21        Self {
22            hp: HighPass::new(hp_len),
23            ss: SuperSmoother::new(lp_len),
24            q,
25            signal_history: [0.0; 2],
26            count: 0,
27        }
28    }
29}
30
31impl Default for SimplePredictor {
32    fn default() -> Self {
33        Self::new(15, 30, 0.35)
34    }
35}
36
37impl Next<f64> for SimplePredictor {
38    type Output = f64;
39
40    fn next(&mut self, input: f64) -> Self::Output {
41        self.count += 1;
42        let signal = self.ss.next(self.hp.next(input));
43
44        let c1 = 1.8 * self.q;
45        let c2 = -self.q * self.q;
46        let sum = 1.0 - c1 - c2;
47
48        let res = if self.count < 3 {
49            signal
50        } else {
51            // Predict = (Signal - c1*Signal[1] - c2*Signal[2]) / sum
52            // Note: Pine script: 
53            // c0 = (1.0 / sum) * Signal
54            // c1_ = (c1 / sum) * Signal[1]
55            // c2_ = (c2 / sum) * Signal[2]
56            // Predict = c0 - c1_ - c2_
57            (signal - c1 * self.signal_history[0] - c2 * self.signal_history[1]) / sum
58        };
59
60        self.signal_history[1] = self.signal_history[0];
61        self.signal_history[0] = signal;
62        
63        res
64    }
65}
66
67pub const SIMPLE_PREDICTOR_METADATA: IndicatorMetadata = IndicatorMetadata {
68    name: "SimplePredictor",
69    description: "A fixed-coefficient 2-pole linear predictive filter.",
70    params: &[
71        ParamDef {
72            name: "hp_len",
73            default: "15",
74            description: "HighPass filter length",
75        },
76        ParamDef {
77            name: "lp_len",
78            default: "30",
79            description: "LowPass (SuperSmoother) length",
80        },
81        ParamDef {
82            name: "q",
83            default: "0.35",
84            description: "Damping/Predictor coefficient",
85        },
86    ],
87    formula_source: "https://github.com/lavs9/quantwave/blob/main/references/traderstipsreference/TRADERS’%20TIPS%20-%20JANUARY%202025.html",
88    formula_latex: r#"
89\[
90Predict = \frac{Signal - 1.8Q \cdot Signal_{t-1} + Q^2 \cdot Signal_{t-2}}{1 - 1.8Q + Q^2}
91\]
92"#,
93    gold_standard_file: "simple_predictor.json",
94    category: "Ehlers DSP",
95};
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Next;
    use proptest::prelude::*;

    /// Smoke test on a steady ramp: every output must be a real, finite number.
    #[test]
    fn test_simple_predictor_basic() {
        let mut sp = SimplePredictor::new(15, 30, 0.35);
        for i in 0..50 {
            let val = sp.next(100.0 + i as f64);
            let bar = i + 1;
            // `is_finite` rejects ±infinity as well as NaN; an overflowing or
            // blown-up division would slip past a plain `is_nan` check.
            assert!(val.is_finite(), "non-finite output {val} at bar {bar}");
        }
    }

    /// `Default` must behave exactly like `new(15, 30, 0.35)`, the defaults
    /// advertised in `SIMPLE_PREDICTOR_METADATA`.
    #[test]
    fn test_simple_predictor_default_matches_new() {
        let mut from_default = SimplePredictor::default();
        let mut from_new = SimplePredictor::new(15, 30, 0.35);
        for i in 0..50 {
            let x = 100.0 + (i as f64 * 0.3).sin() * 5.0;
            let (a, b) = (from_default.next(x), from_new.next(x));
            // Bitwise comparison: both paths run the identical computation.
            assert_eq!(a.to_bits(), b.to_bits(), "outputs diverged at bar {}", i + 1);
        }
    }

    proptest! {
        /// Streaming output must match an independent batch re-implementation.
        #[test]
        fn test_simple_predictor_parity(
            inputs in prop::collection::vec(1.0..100.0, 50..100),
        ) {
            let hp_len = 15;
            let lp_len = 30;
            let q = 0.35;
            let mut sp = SimplePredictor::new(hp_len, lp_len, q);
            let streaming_results: Vec<f64> = inputs.iter().map(|&x| sp.next(x)).collect();

            // Batch implementation
            let mut batch_results = Vec::with_capacity(inputs.len());
            let mut hp = HighPass::new(hp_len);
            let mut ss = SuperSmoother::new(lp_len);
            let signal_vals: Vec<f64> = inputs.iter().map(|&x| ss.next(hp.next(x))).collect();

            let c1 = 1.8 * q;
            let c2 = -q * q;
            let sum = 1.0 - c1 - c2;

            for (i, &signal) in signal_vals.iter().enumerate() {
                let bar = i + 1;
                let res = if bar < 3 {
                    signal
                } else {
                    (signal - c1 * signal_vals[i - 1] - c2 * signal_vals[i - 2]) / sum
                };
                batch_results.push(res);
            }

            for (s, b) in streaming_results.iter().zip(batch_results.iter()) {
                approx::assert_relative_eq!(s, b, epsilon = 1e-10);
            }
        }
    }
}