//! Ehlers' "Swiss Army Knife" indicator: one second-order IIR filter structure
//! whose coefficients select EMA, SMA, Gaussian, Butterworth, smoothing,
//! high-pass, band-pass or band-stop behaviour.
//!
//! Source file: `quantwave_core/indicators/swiss_army_knife.rs`.

use crate::indicators::metadata::{IndicatorMetadata, ParamDef};
use crate::traits::Next;

/// Filter configuration selected for a [`SwissArmyKnife`].
///
/// Every variant feeds the same difference equation; the mode only decides
/// which coefficients are non-trivial.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SwissArmyKnifeMode {
    /// Exponential moving average (single feedback pole).
    EMA,
    /// Simple moving average over `period` samples, computed as a running sum.
    SMA,
    /// Two-pole Gaussian low-pass filter.
    Gauss,
    /// Two-pole Butterworth low-pass filter.
    Butterworth,
    /// Fixed `(x + 2*x[1] + x[2]) / 4` smoother (independent of `period`).
    Smooth,
    /// Single-pole high-pass filter.
    HighPass,
    /// Two-pole high-pass filter.
    TwoPoleHighPass,
    /// Band-pass filter centred on `period`; `delta` controls the bandwidth.
    BandPass,
    /// Band-stop filter centred on `period`; `delta` controls the bandwidth.
    BandStop,
}

18#[derive(Debug, Clone)]
24pub struct SwissArmyKnife {
25 mode: SwissArmyKnifeMode,
26 period: usize,
27 delta: f64,
28 c0: f64,
29 c1: f64,
30 b0: f64,
31 b1: f64,
32 b2: f64,
33 a1: f64,
34 a2: f64,
35 x: [f64; 3], f: [f64; 2], history_x: Vec<f64>, count: usize,
39}
40
41impl SwissArmyKnife {
42 pub fn new(mode: SwissArmyKnifeMode, period: usize, delta: f64) -> Self {
43 let mut sak = Self {
44 mode,
45 period,
46 delta,
47 c0: 1.0,
48 c1: 0.0,
49 b0: 1.0,
50 b1: 0.0,
51 b2: 0.0,
52 a1: 0.0,
53 a2: 0.0,
54 x: [0.0; 3],
55 f: [0.0; 2],
56 history_x: Vec::new(),
57 count: 0,
58 };
59 sak.initialize();
60 sak
61 }
62
63 fn initialize(&mut self) {
64 let period_f = self.period as f64;
65 let angle = 2.0 * std::f64::consts::PI / period_f;
66
67 match self.mode {
68 SwissArmyKnifeMode::EMA => {
69 let alpha = (angle.cos() + angle.sin() - 1.0) / angle.cos();
70 self.b0 = alpha;
71 self.a1 = 1.0 - alpha;
72 }
73 SwissArmyKnifeMode::SMA => {
74 self.c1 = 1.0 / period_f;
75 self.b0 = 1.0 / period_f;
76 self.a1 = 1.0;
77 }
78 SwissArmyKnifeMode::Gauss => {
79 let beta = 2.415 * (1.0 - angle.cos());
80 let alpha = -beta + (beta * beta + 2.0 * beta).sqrt();
81 self.c0 = alpha * alpha;
82 self.a1 = 2.0 * (1.0 - alpha);
83 self.a2 = -(1.0 - alpha) * (1.0 - alpha);
84 }
85 SwissArmyKnifeMode::Butterworth => {
86 let beta = 2.415 * (1.0 - angle.cos());
87 let alpha = -beta + (beta * beta + 2.0 * beta).sqrt();
88 self.c0 = alpha * alpha / 4.0;
89 self.b1 = 2.0;
90 self.b2 = 1.0;
91 self.a1 = 2.0 * (1.0 - alpha);
92 self.a2 = -(1.0 - alpha) * (1.0 - alpha);
93 }
94 SwissArmyKnifeMode::Smooth => {
95 self.c0 = 0.25;
96 self.b1 = 2.0;
97 self.b2 = 1.0;
98 }
99 SwissArmyKnifeMode::HighPass => {
100 let alpha = (angle.cos() + angle.sin() - 1.0) / angle.cos();
101 self.c0 = 1.0 - alpha / 2.0;
102 self.b1 = -1.0;
103 self.a1 = 1.0 - alpha;
104 }
105 SwissArmyKnifeMode::TwoPoleHighPass => {
106 let beta = 2.415 * (1.0 - angle.cos());
107 let alpha = -beta + (beta * beta + 2.0 * beta).sqrt();
108 self.c0 = (1.0 - alpha / 2.0) * (1.0 - alpha / 2.0);
109 self.b1 = -2.0;
110 self.b2 = 1.0;
111 self.a1 = 2.0 * (1.0 - alpha);
112 self.a2 = -(1.0 - alpha) * (1.0 - alpha);
113 }
114 SwissArmyKnifeMode::BandPass => {
115 let beta = angle.cos();
116 let gamma = 1.0 / (4.0 * std::f64::consts::PI * self.delta / period_f).cos();
117 let alpha = gamma - (gamma * gamma - 1.0).sqrt();
118 self.c0 = (1.0 - alpha) / 2.0;
119 self.b2 = -1.0;
120 self.a1 = beta * (1.0 + alpha);
121 self.a2 = -alpha;
122 }
123 SwissArmyKnifeMode::BandStop => {
124 let beta = angle.cos();
125 let gamma = 1.0 / (4.0 * std::f64::consts::PI * self.delta / period_f).cos();
126 let alpha = gamma - (gamma * gamma - 1.0).sqrt();
127 self.c0 = (1.0 + alpha) / 2.0;
128 self.b1 = -2.0 * beta;
129 self.b2 = 1.0;
130 self.a1 = beta * (1.0 + alpha);
131 self.a2 = -alpha;
132 }
133 }
134 }
135}
136
137impl Next<f64> for SwissArmyKnife {
138 type Output = f64;
139
140 fn next(&mut self, input: f64) -> Self::Output {
141 self.count += 1;
142 self.x[2] = self.x[1];
143 self.x[1] = self.x[0];
144 self.x[0] = input;
145
146 if self.mode == SwissArmyKnifeMode::SMA {
147 self.history_x.push(input);
148 }
149
150 let filt = if self.count <= self.period {
151 match self.mode {
152 SwissArmyKnifeMode::HighPass | SwissArmyKnifeMode::TwoPoleHighPass => 0.0,
153 _ => input,
154 }
155 } else {
156 let x_n = if self.mode == SwissArmyKnifeMode::SMA {
157 self.history_x[self.count - 1 - self.period]
158 } else {
159 0.0
160 };
161
162 self.c0 * (self.b0 * self.x[0] + self.b1 * self.x[1] + self.b2 * self.x[2])
163 + self.a1 * self.f[0]
164 + self.a2 * self.f[1]
165 - self.c1 * x_n
166 };
167
168 self.f[1] = self.f[0];
169 self.f[0] = filt;
170
171 filt
172 }
173}
174
175pub const SWISS_ARMY_KNIFE_METADATA: IndicatorMetadata = IndicatorMetadata {
176 name: "Swiss Army Knife Indicator",
177 description: "A versatile indicator that can be configured as EMA, SMA, Gaussian, Butterworth, High Pass, Band Pass, or Band Stop filter.",
178 params: &[
179 ParamDef {
180 name: "mode",
181 default: "BandPass",
182 description: "Filter mode (EMA, SMA, Gauss, Butter, Smooth, HP, 2PHP, BP, BS)",
183 },
184 ParamDef {
185 name: "period",
186 default: "20",
187 description: "Filter period",
188 },
189 ParamDef {
190 name: "delta",
191 default: "0.1",
192 description: "Bandwidth parameter for BP and BS modes",
193 },
194 ],
195 formula_source: "https://github.com/lavs9/quantwave/blob/main/references/Ehlers%20Papers/SwissArmyKnifeIndicator.pdf",
196 formula_latex: r#"
197\[
198Filt = c_0(b_0 x_t + b_1 x_{t-1} + b_2 x_{t-2}) + a_1 Filt_{t-1} + a_2 Filt_{t-2} - c_1 x_{t-N}
199\]
200"#,
201 gold_standard_file: "swiss_army_knife.json",
202 category: "Ehlers DSP",
203};
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Next;
    use proptest::prelude::*;

    /// A freshly built EMA filter fed a short, well-behaved series must never
    /// emit NaN.
    #[test]
    fn test_sak_ema_basic() {
        let mut sak = SwissArmyKnife::new(SwissArmyKnifeMode::EMA, 20, 0.1);
        for &price in &[10.0, 11.0, 12.0, 11.0, 10.0] {
            assert!(!sak.next(price).is_nan());
        }
    }

    proptest! {
        // The streaming Gauss filter must agree with an independent,
        // hand-rolled evaluation of the same two-pole recursion.
        #[test]
        fn test_sak_parity(
            inputs in prop::collection::vec(1.0..100.0, 30..100),
        ) {
            let period = 20;
            let delta = 0.1;
            let mut sak = SwissArmyKnife::new(SwissArmyKnifeMode::Gauss, period, delta);
            let streaming: Vec<f64> = inputs.iter().map(|&x| sak.next(x)).collect();

            // Reference Gauss coefficients.
            let angle = 2.0 * std::f64::consts::PI / period as f64;
            let beta = 2.415 * (1.0 - angle.cos());
            let alpha = -beta + (beta * beta + 2.0 * beta).sqrt();
            let c0 = alpha * alpha;
            let a1 = 2.0 * (1.0 - alpha);
            let a2 = -(1.0 - alpha) * (1.0 - alpha);

            // Gauss uses only the current input, so no input taps are tracked.
            let mut prev1 = 0.0; // Filt[t-1]
            let mut prev2 = 0.0; // Filt[t-2]
            let mut batch = Vec::with_capacity(inputs.len());
            for (bar, &price) in inputs.iter().enumerate() {
                let filt = if bar < period {
                    price
                } else {
                    c0 * price + a1 * prev1 + a2 * prev2
                };
                prev2 = prev1;
                prev1 = filt;
                batch.push(filt);
            }

            for (s, b) in streaming.iter().zip(&batch) {
                approx::assert_relative_eq!(s, b, epsilon = 1e-10);
            }
        }
    }
}