// pluto_sdr/filter.rs
//
// Author: Roman Hayn
// MIT License 2023

use std::f32::consts::{FRAC_2_PI, PI};

/// A finite impulse response (FIR) filter that can be applied to signals given as
/// plain `f32` samples (e.g. a `Vec<f32>` or any `&[f32]`).
///
/// ```
/// use pluto_sdr::filter::Filter;
/// // arbitrary signal
/// let signal: Vec<f32> = vec![1., 4., -3., 5., 10., -19., -4., 10., -2., -1.];
/// // the coefficients describe a 4 point moving average
/// let lpf_coefficients = vec![0.25, 0.25, 0.25, 0.25];
/// // create the low pass filter
/// let lpf = Filter::new(lpf_coefficients);
///
/// // filter the signal: N - M + 1 = 10 - 4 + 1 = 7 output samples
/// let filtered = lpf.filter_windowed(&signal);
/// assert_eq!(filtered.len(), 7);
/// println!("{:?}", filtered);
/// ```
#[derive(Clone, PartialEq, Debug)]
pub struct Filter {
    /// FIR coefficients `[b_0, b_1, ..., b_{M-1}]`; `b_k` weights the input delayed by `k`.
    coeff: Vec<f32>,
}

impl Filter {
    /// Create a new FIR filter from its coefficients `[b_0, b_1, b_2, ...]`,
    /// where `b_k` weights the input sample delayed by `k` samples.
    pub fn new(coeff: Vec<f32>) -> Self {
        Filter { coeff }
    }

    /// Number of filter coefficients (the filter length `M`).
    pub fn len(&self) -> usize {
        self.coeff.len()
    }

    /// Returns `true` if the filter has no coefficients.
    pub fn is_empty(&self) -> bool {
        self.coeff.is_empty()
    }

    /// Convolve `signal` with the filter, keeping only the output samples where the
    /// filter fully overlaps the signal ("valid" convolution).
    ///
    /// For a signal of length `N` and a filter of length `M` the output has
    /// `N - M + 1` samples:
    ///
    /// `y[n] = sum_{k=0}^{M-1} b_k * x[n + M - 1 - k]`
    ///
    /// i.e. `b_0` weights the newest sample of each window. For symmetric filters
    /// (moving average, SRRC) this is identical to a plain correlation.
    ///
    /// Returns an empty vector if `signal` is shorter than the filter, or if the
    /// filter has no coefficients.
    pub fn filter_windowed(&self, signal: &[f32]) -> Vec<f32> {
        // `slice::windows(0)` panics, so an empty filter is handled up front.
        if self.coeff.is_empty() {
            return Vec::new();
        }
        signal
            .windows(self.coeff.len())
            .map(|window| {
                // Reversed coefficients turn the sliding dot product into a convolution.
                window
                    .iter()
                    .zip(self.coeff.iter().rev())
                    .map(|(x, b)| x * b)
                    .sum()
            })
            .collect()
    }

    /// Creates a moving average filter of length `size`, where every coefficient
    /// equals `1 / size`. A `size` of `0` yields an empty filter.
    pub fn create_moving_average(size: usize) -> Self {
        Filter::new(vec![1.0 / size as f32; size])
    }

    /// Creates a square-root raised cosine (SRRC) pulse, e.g. for use as a matched filter.
    ///
    /// An SRRC filter at the transmitter, matched by one at the receiver, yields a
    /// raised cosine pulse with (ideally) no inter-symbol interference (ISI).
    ///
    /// * `n_half` - half the pulse length in symbol periods.
    /// * `b` - roll-off factor beta, expected in `[0, 1]`; `b = 0` yields a sinc pulse.
    /// * `m` - oversampling factor (samples per symbol).
    ///
    /// The pulse has `2 * n_half * m + 1` coefficients. The length is always odd, so the
    /// peak (`t = 0`) falls exactly on the middle sample.
    ///
    /// Ported from Matlab `srrc.m` (see the book "Software Receiver Design").
    ///
    /// # Panics
    /// Panics if `m` is `0`.
    pub fn create_srrc(n_half: usize, b: f32, m: usize) -> Self {
        assert!(m > 0, "oversampling factor `m` must be at least 1");

        let m_f = m as f32;
        // index of the middle sample (t = 0) after right-shifting the pulse
        let offset = n_half * m;

        // left half: t = -offset ..= -1 samples, converted to symbol periods
        let left: Vec<f32> = (0..offset)
            .map(|i| Self::srrc_sample((i as f32 - offset as f32) / m_f, b, m_f))
            .collect();
        // t = 0: limit of the general formula
        let center = ((1.0 - b) + 4.0 * b / PI) / m_f.sqrt();

        let mut coeff = Vec::with_capacity(2 * offset + 1);
        coeff.extend_from_slice(&left);
        coeff.push(center);
        // the pulse is even, so the right half mirrors the left half around the center
        coeff.extend(left.iter().rev().copied());

        Filter { coeff }
    }

    /// Evaluates the SRRC pulse, scaled by `1 / sqrt(m)`, at a non-zero time `t`
    /// given in symbol periods (`t = sample_offset / m`).
    fn srrc_sample(t: f32, b: f32, m: f32) -> f32 {
        // |denominator| below this is treated as the removable singularity.
        const SINGULARITY_TOL: f32 = 1e-4;

        let x = 4.0 * b * t;
        let denom = PI * (1.0 - x * x);
        if denom.abs() < SINGULARITY_TOL {
            // t = +/- 1/(4b): numerator and denominator both vanish. Use the analytic
            // limit instead of relying on 0/0 producing NaN (rounding may give inf).
            let arg = PI / (4.0 * b);
            b / (2.0 * m).sqrt()
                * ((1.0 + FRAC_2_PI) * arg.sin() + (1.0 - FRAC_2_PI) * arg.cos())
        } else {
            // 4b * [cos(..) + sin(..) / (4bt)] rewritten as 4b*cos(..) + sin(..)/t,
            // so that b = 0 (sinc pulse) does not divide by zero.
            (4.0 * b * ((1.0 + b) * PI * t).cos() + ((1.0 - b) * PI * t).sin() / t)
                / (m.sqrt() * denom)
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_srrc_len() {
        // size should be 1 + 2 * n_half * m = 1 + 2 * 4 * 2
        let srrc = Filter::create_srrc(4, 0.5, 2);
        assert_eq!(srrc.len(), 17);
    }

    #[test]
    fn test_srrc_symmetry() {
        let srrc = Filter::create_srrc(4, 0.5, 2);

        // compare every sample left of the middle with its mirror image:
        // sample i against sample len - i - 1
        let len = srrc.len();
        for i in 0..len / 2 {
            assert_eq!(srrc.coeff[i], srrc.coeff[len - i - 1]);
        }
    }

    #[test]
    fn test_srrc_values() {
        let srrc = Filter::create_srrc(4, 0.5, 1);
        // reference values computed in Octave with srrc.m from the
        // "Software Receiver Design" book.
        // Only the left half needs checking since the pulse is symmetric
        // (the last entry is the middle sample of the pulse).
        let srrc_mat = [-0.010105, 0.0030315, 0.042441, -0.1061, 1.1366];

        // maximum tolerated absolute difference between Octave and this implementation
        let tolerance = 0.0001;
        for (i, &expected) in srrc_mat.iter().enumerate() {
            let abs_err = (srrc.coeff[i] - expected).abs();
            assert!(
                abs_err < tolerance,
                "sample {}: got {}, expected {}",
                i,
                srrc.coeff[i],
                expected
            );
        }
    }

    #[test]
    fn test_filter() {
        // N - M + 1 -> 7 - 3 + 1 = 5
        let filter = Filter::new(vec![1., 0., 1.]);
        let sig = vec![1., 2., 3., 4., 5., 6., 7.];

        let expected_result = vec![4., 6., 8., 10., 12.];
        assert_eq!(expected_result, filter.filter_windowed(&sig));
    }

    #[test]
    fn create_moving_average() {
        let vec = vec![0.2, 0.2, 0.2, 0.2, 0.2];
        let filter_ma_5 = Filter::new(vec);
        assert_eq!(filter_ma_5, Filter::create_moving_average(5));

        let vec = vec![0.25, 0.25, 0.25, 0.25];
        let filter_ma_4 = Filter::new(vec);
        assert_eq!(filter_ma_4, Filter::create_moving_average(4));
    }
}