// pot_head/filters/ema.rs

/// Exponential moving average (EMA) filter state.
///
/// The filter remembers only its most recent output. Until it has been seeded,
/// either by the first [`apply`](Self::apply) call or by
/// [`reset`](Self::reset), it holds no value at all.
#[derive(Debug, Clone, Copy, Default)]
pub struct EmaFilter {
    /// Most recent output, or `None` while the filter is still unseeded.
    ///
    /// An `Option` is used instead of a value plus a separate "initialized"
    /// flag so that a meaningless stored value cannot exist.
    state: Option<f32>,
}

impl EmaFilter {
    /// Creates a filter that has not been seeded yet.
    #[must_use]
    pub const fn new() -> Self {
        Self { state: None }
    }

    /// Feeds one sample through the filter and returns the smoothed output.
    ///
    /// Once seeded, the output is `alpha * input + (1 - alpha) * previous`,
    /// where `previous` is the last value returned by `apply` or passed to
    /// [`reset`](Self::reset). On an unseeded filter the input is adopted
    /// unchanged and returned as-is.
    ///
    /// The returned value is also stored as the new filter state.
    ///
    /// # Panics
    ///
    /// When debug assertions are enabled, panics if `alpha` is not within
    /// `(0.0, 1.0]` (a NaN `alpha` fails this check as well). Without debug
    /// assertions the range is not checked.
    pub fn apply(&mut self, input: f32, alpha: f32) -> f32 {
        debug_assert!(
            alpha > 0.0 && alpha <= 1.0,
            "EMA alpha must be in range (0.0, 1.0], got {}",
            alpha
        );

        let output = match self.state {
            // Unseeded: adopt the first sample instead of blending it with an
            // arbitrary starting value.
            None => input,
            Some(previous) => alpha * input + (1.0 - alpha) * previous,
        };
        self.state = Some(output);
        output
    }

    /// Seeds the filter with a known value, discarding any earlier state.
    ///
    /// The next [`apply`](Self::apply) call blends its input with `value`
    /// instead of adopting that input unchanged, which is what an unseeded
    /// filter would do.
    pub fn reset(&mut self, value: f32) {
        self.state = Some(value);
    }
}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing filter output with hand-computed values.
    const EPSILON: f32 = 1e-6;

    /// The very first sample must be returned unchanged.
    #[test]
    fn first_call_returns_input() {
        let mut filter = EmaFilter::new();
        assert_eq!(filter.apply(0.5, 0.3), 0.5);
    }

    /// A step from 0.0 to 1.0 moves the output by exactly `alpha`.
    #[test]
    fn applies_smoothing() {
        let mut filter = EmaFilter::new();
        filter.apply(0.0, 0.3);

        // Step input from 0.0 to 1.0:
        // output = 0.3 * 1.0 + 0.7 * 0.0 = 0.3
        let output = filter.apply(1.0, 0.3);
        assert!((output - 0.3).abs() < EPSILON);
    }

    /// A higher alpha follows a step input more closely than a lower one.
    #[test]
    fn lower_alpha_more_smoothing() {
        let mut filter_low = EmaFilter::new();
        let mut filter_high = EmaFilter::new();

        filter_low.apply(0.0, 0.1);
        filter_high.apply(0.0, 0.9);

        let out_low = filter_low.apply(1.0, 0.1);
        let out_high = filter_high.apply(1.0, 0.9);

        assert!(out_high > out_low);
    }

    /// `reset` replaces whatever state earlier samples left behind.
    #[test]
    fn reset_seeds_to_value() {
        let mut filter = EmaFilter::new();
        filter.apply(0.5, 0.3);
        filter.apply(0.7, 0.3);

        filter.reset(0.8);

        // The next sample is blended with the seeded value, not with the
        // earlier history.
        let expected: f32 = 0.3 * 1.0 + 0.7 * 0.8;
        assert!((filter.apply(1.0, 0.3) - expected).abs() < EPSILON);
    }

    /// `reset` on a filter that has never seen a sample still seeds it, so the
    /// first `apply` blends instead of passing the input through.
    #[test]
    fn reset_before_first_apply_seeds_filter() {
        let mut filter = EmaFilter::default();
        filter.reset(0.8);

        // output = 0.5 * 1.0 + 0.5 * 0.8 = 0.9
        let output = filter.apply(1.0, 0.5);
        assert!((output - 0.9).abs() < EPSILON);
    }

    /// With `alpha == 1.0` the previous value has zero weight, so the output
    /// equals the input.
    #[test]
    fn alpha_one_passes_input_through() {
        let mut filter = EmaFilter::new();
        filter.apply(0.2, 1.0);
        assert_eq!(filter.apply(0.7, 1.0), 0.7);
    }
}