Skip to main content

rustradio/
fir.rs

1//! Finite impulse response filter.
2//!
3//! If using many taps, [`FftFilter`](crate::blocks::FftFilter) probably has
4//! better performance.
5/*
6 * TODO:
7 * * Only handles case where input, output, and tap type are all the same.
8 */
9use crate::block::{Block, BlockRet};
10use crate::stream::{ReadStream, WriteStream};
11use crate::window::{Window, WindowType};
12use crate::{Complex, Float, Result, Sample};
13
/// Finite impulse response filter.
///
/// Owns the filter taps and evaluates them against slices of input samples.
/// This is the bare filter kernel; it has no stream plumbing of its own.
#[derive(Debug, Clone)]
pub struct Fir<T> {
    // Filter coefficients, in the order they are multiplied against the
    // input slice (`Fir::new` stores the caller's taps reversed).
    taps: Vec<T>,
}
18
/// Dot product of two equally long `f32` slices using AVX.
///
/// Eight products at a time are accumulated in a 256-bit register; whatever
/// does not fill a full group of eight is added with scalar arithmetic.
///
/// # Panics
///
/// Panics if the two slices differ in length.
#[cfg(all(
    target_feature = "avx",
    target_feature = "sse3",
    target_feature = "sse"
))]
fn sum_product_avx(vec1: &[f32], vec2: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    // Number of f32 lanes in a 256-bit AVX register.
    const LANES: usize = 8;

    assert_eq!(vec1.len(), vec2.len());

    let groups1 = vec1.chunks_exact(LANES);
    let groups2 = vec2.chunks_exact(LANES);
    let tail1 = groups1.remainder();
    let tail2 = groups2.remainder();

    // SAFETY: every slice yielded by `chunks_exact(LANES)` holds exactly
    // eight `f32`s, so each unaligned 256-bit load stays inside its slice.
    // The AVX/SSE3/SSE instructions used here are guaranteed to exist by the
    // `cfg` on this function, which requires them at compile time.
    let partial = unsafe {
        // AVX: running sum of the eight lanes.
        let mut sum = _mm256_setzero_ps();
        for (g1, g2) in groups1.zip(groups2) {
            // AVX: load, multiply, accumulate.
            let a = _mm256_loadu_ps(g1.as_ptr());
            let b = _mm256_loadu_ps(g2.as_ptr());
            sum = _mm256_add_ps(sum, _mm256_mul_ps(a, b));
        }

        // AVX: split the eight lanes into two halves of four.
        let low = _mm256_extractf128_ps(sum, 0);
        let high = _mm256_extractf128_ps(sum, 1);

        // SSE3: three horizontal adds fold eight partial sums into lane 0.
        // Only lane 0 is read afterwards, so the second operand of the last
        // two adds is just filler.
        let m128 = _mm_hadd_ps(low, high);
        let m128 = _mm_hadd_ps(m128, low);
        let m128 = _mm_hadd_ps(m128, low);

        // SSE: extract lane 0.
        _mm_cvtss_f32(m128)
    };

    // Scalar tail for the final `len % 8` elements.
    tail1
        .iter()
        .zip(tail2)
        .fold(partial, |acc, (&f, &x)| acc + x * f)
}
72
73impl Fir<Float> {
74    /// Run filter once, creating one sample from the taps and an
75    /// equal number of input samples.
76    #[must_use]
77    pub fn filter_float(&self, input: &[Float]) -> Float {
78        // AVX is faster, when available.
79        #[cfg(all(
80            target_feature = "avx",
81            target_feature = "sse3",
82            target_feature = "sse"
83        ))]
84        return sum_product_avx(&self.taps, input);
85        // Second fastest is generic simd.
86        #[cfg(feature = "simd")]
87        #[allow(unreachable_code)]
88        {
89            use std::simd::num::SimdFloat;
90            type Batch = std::simd::f32x8;
91
92            let batch_n = 8;
93            // How will this work if Float is f64?
94            let partial = input
95                .chunks_exact(batch_n)
96                .zip(self.taps.chunks_exact(batch_n))
97                .map(|(a, b)| Batch::from_slice(a) * Batch::from_slice(b))
98                .fold(Batch::splat(0.0), |acc, x| acc + x)
99                .reduce_sum();
100            // Maybe even faster if doing a second round with f32x4.
101            let skip = self.taps.len() - self.taps.len() % batch_n;
102            return input[skip..]
103                .iter()
104                .zip(self.taps[skip..].iter())
105                .fold(partial, |acc, (&f, &x)| acc + x * f);
106        }
107        #[allow(unreachable_code)]
108        self.filter(input)
109    }
110}
111
112impl<T> Fir<T>
113where
114    T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
115{
116    /// Create new Fir.
117    #[must_use]
118    pub fn new(taps: &[T]) -> Self {
119        Self {
120            taps: taps.iter().copied().rev().collect(),
121        }
122    }
123    /// Run filter once, creating one sample from the taps and an
124    /// equal number of input samples.
125    #[must_use]
126    pub fn filter(&self, input: &[T]) -> T {
127        assert!(
128            input.len() >= self.taps.len(),
129            "input {} < taps {}",
130            input.len(),
131            self.taps.len()
132        );
133        input
134            .iter()
135            .zip(self.taps.iter())
136            .fold(T::default(), |acc, (&f, &x)| acc + x * f)
137    }
138
139    /// Call `filter()` multiple times, across an input range.
140    #[must_use]
141    pub fn filter_n(&self, input: &[T], deci: usize) -> Vec<T> {
142        let n = input.len() - self.taps.len();
143        (0..=n)
144            .step_by(deci)
145            .map(|i| self.filter(&input[i..]))
146            .collect()
147    }
148
149    /// Like `filter_n`, but avoids a copy when there's a destination in mind.
150    pub fn filter_n_inplace(&self, input: &[T], deci: usize, out: &mut [T]) {
151        out.iter_mut()
152            .enumerate()
153            .for_each(|(i, o)| *o = self.filter(&input[(i * deci)..]));
154    }
155}
156
/// Builder for a FIR filter block.
///
/// A builder is needed to create a decimating FIR filter block.
#[derive(Debug, Clone)]
pub struct FirFilterBuilder<T> {
    // Taps handed to the filter block when it is built.
    taps: Vec<T>,
    // Decimation factor applied to the built block; 1 means no decimation.
    deci: usize,
}
164
165impl<T> FirFilterBuilder<T>
166where
167    T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
168{
169    /// Set the decimation to the given value.
170    ///
171    /// The default is 1, meaning no decimation.
172    #[must_use]
173    pub fn deci(mut self, deci: usize) -> Self {
174        self.deci = deci;
175        self
176    }
177
178    /// Build a `FirFilter` with the provided settings.
179    #[must_use]
180    pub fn build(self, src: ReadStream<T>) -> (FirFilter<T>, ReadStream<T>) {
181        let (mut block, stream) = FirFilter::new(src, &self.taps);
182        block.deci = self.deci;
183        (block, stream)
184    }
185}
186
/// Finite impulse response filter block.
///
/// Reads samples from `src`, runs them through a [`Fir`], and writes the
/// filtered (and optionally decimated) samples to `dst`.
#[derive(rustradio_macros::Block)]
#[rustradio(crate)]
pub struct FirFilter<T: Sample> {
    // The filter kernel that does the actual arithmetic.
    fir: Fir<T>,
    // Number of taps; `new()` records `taps.len()` here.
    ntaps: usize,
    // Decimation factor. `new()` sets 1; the builder may override it.
    deci: usize,
    // Input stream.
    #[rustradio(in)]
    src: ReadStream<T>,
    // Output stream.
    #[rustradio(out)]
    dst: WriteStream<T>,
}
199
200impl<T> FirFilter<T>
201where
202    T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
203{
204    /// Create new `FirFilterBuilder`, with the supplied taps.
205    pub fn builder(taps: &[T]) -> FirFilterBuilder<T> {
206        FirFilterBuilder {
207            taps: taps.to_vec(),
208            deci: 1,
209        }
210    }
211    /// Create Fir block given taps.
212    pub fn new(src: ReadStream<T>, taps: &[T]) -> (Self, ReadStream<T>) {
213        let (dst, dr) = crate::stream::new_stream();
214        (
215            Self {
216                src,
217                dst,
218                ntaps: taps.len(),
219                deci: 1,
220                fir: Fir::new(taps),
221            },
222            dr,
223        )
224    }
225}
226
227impl<T> Block for FirFilter<T>
228where
229    T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
230{
231    fn work(&mut self) -> Result<BlockRet<'_>> {
232        let (input, mut tags) = self.src.read_buf()?;
233
234        // Get number of input samples we intend to consume.
235        let n = {
236            // Carefully avoid underflow.
237            let absolute_minimum = self.ntaps + self.deci - 1;
238            if input.len() < absolute_minimum {
239                return Ok(BlockRet::WaitForStream(&self.src, absolute_minimum));
240            }
241            self.deci * ((input.len() - self.ntaps + 1) / self.deci)
242        };
243        assert_ne!(n, 0);
244
245        // To consume `n`, we may need more input samples than that.
246        let need = n + self.ntaps - 1;
247        assert!(input.len() >= need, "need {need}, have {}", input.len());
248
249        // Output must have room for at least one sample.
250        let mut out = self.dst.write_buf()?;
251        let need_out = 1;
252        if out.len() < need_out {
253            return Ok(BlockRet::WaitForStream(&self.dst, need_out));
254        }
255
256        // Cap by output capacity.
257        let n = std::cmp::min(n, out.len() * self.deci);
258
259        // Final `n` (samples to consume) calculated. Sanity check it.
260        assert_eq!(n % self.deci, 0);
261        assert_ne!(n, 0, "input: {} out: {}", input.len(), out.len());
262
263        // Run the FIR.
264        let out_n = n / self.deci;
265        self.fir
266            .filter_n_inplace(&input.slice()[..need], self.deci, &mut out.slice()[..out_n]);
267
268        // Sanity check the generated output.
269        assert!(out_n <= out.len());
270
271        input.consume(n);
272        if self.deci == 1 {
273            out.produce(out_n, &tags);
274        } else {
275            for t in &mut tags {
276                t.set_pos(t.pos() / self.deci);
277            }
278            out.produce(out_n, &tags);
279        }
280        // While we could keep track of which stream is the constraining factor,
281        // the code is simpler if work() is just called again, and the right
282        // WaitForStream is returned above instead.
283        Ok(BlockRet::Again)
284    }
285}
286
287/// Create a multiband filter.
288///
289/// TODO: this is untested.
290#[must_use]
291pub fn multiband(bands: &[(Float, Float)], taps: usize, window: &Window) -> Option<Vec<Complex>> {
292    use rustfft::FftPlanner;
293
294    if taps != window.0.len() {
295        return None;
296    }
297
298    let mut ideal = vec![Complex::new(0.0, 0.0); taps];
299    let scale = (taps as Float) / 2.0;
300    for (low, high) in bands {
301        let a = (low * scale).floor() as usize;
302        let b = (high * scale).ceil() as usize;
303        for n in a..b {
304            ideal[n] = Complex::new(1.0, 0.0);
305            ideal[taps - n - 1] = Complex::new(1.0, 0.0);
306        }
307    }
308    let fft_size = taps;
309    let mut planner = FftPlanner::new();
310    let ifft = planner.plan_fft_inverse(fft_size);
311    ifft.process(&mut ideal);
312    ideal.rotate_right(taps / 2);
313    let scale = (fft_size as Float).sqrt();
314    Some(
315        ideal
316            .into_iter()
317            .enumerate()
318            .map(|(n, v)| v * window.0[n] / Complex::new(scale, 0.0))
319            .collect(),
320    )
321}
322
323/// Create taps for a low pass filter as complex taps.
324#[must_use]
325pub fn low_pass_complex(
326    samp_rate: Float,
327    cutoff: Float,
328    twidth: Float,
329    window_type: &WindowType,
330) -> Vec<Complex> {
331    low_pass(samp_rate, cutoff, twidth, window_type)
332        .into_iter()
333        .map(|t| Complex::new(t, 0.0))
334        .collect()
335}
336
337fn compute_ntaps(samp_rate: Float, twidth: Float, window_type: &WindowType) -> usize {
338    let a = window_type.max_attenuation();
339    let t = (a * samp_rate / (22.0 * twidth)) as usize;
340    if (t & 1) == 0 { t + 1 } else { t }
341}
342
343/// Create taps for a low pass filter.
344///
345/// TODO: this could be faster if we supported filtering a Complex by a Float.
346/// A low pass filter doesn't actually need complex taps.
347#[must_use]
348pub fn low_pass(
349    samp_rate: Float,
350    cutoff: Float,
351    twidth: Float,
352    window_type: &WindowType,
353) -> Vec<Float> {
354    let pi = std::f64::consts::PI as Float;
355    let ntaps = compute_ntaps(samp_rate, twidth, window_type);
356    let window = window_type.make_window(ntaps);
357    let m = (ntaps - 1) / 2;
358    let fwt0 = 2.0 * pi * cutoff / samp_rate;
359    let taps: Vec<_> = window
360        .0
361        .iter()
362        .enumerate()
363        .map(|(nm, win)| {
364            let n = nm as i64 - m as i64;
365            let nf = n as Float;
366            if n == 0 {
367                fwt0 / pi * win
368            } else {
369                ((nf * fwt0).sin() / (nf * pi)) * win
370            }
371        })
372        .collect();
373    let gain = {
374        let gain: Float = 1.0;
375        let mut fmax = taps[m];
376        for n in 1..=m {
377            fmax += 2.0 * taps[n + m];
378        }
379        gain / fmax
380    };
381    taps.into_iter().map(|t| t * gain).collect()
382}
383
384/// Generate hilbert transformer filter.
385#[must_use]
386pub fn hilbert(window: &Window) -> Vec<Float> {
387    let ntaps = window.0.len();
388    let mid = (ntaps - 1) / 2;
389    let mut gain = 0.0;
390    let mut taps = vec![0.0; ntaps];
391    for i in 1..=mid {
392        if i & 1 == 1 {
393            let x = 1.0 / (i as Float);
394            taps[mid + i] = x * window.0[mid + i];
395            taps[mid - i] = -x * window.0[mid - i];
396            gain = taps[mid + i] - gain;
397        } else {
398            taps[mid + i] = 0.0;
399            taps[mid - i] = 0.0;
400        }
401    }
402    let gain = 1.0 / (2.0 * gain.abs());
403    taps.iter().map(|e| gain * *e).collect()
404}
405
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::Repeat;
    use crate::blocks::VectorSource;
    use crate::stream::{Tag, TagValue};
    use crate::tests::assert_almost_equal_complex;

    /// Six-sample complex input shared by the filter tests.
    fn ramp_input() -> Vec<Complex> {
        vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.2),
            Complex::new(4.1, 0.0),
            Complex::new(5.0, 0.0),
            Complex::new(6.0, 0.2),
        ]
    }

    /// A single unity tap must pass the input through, at every decimation,
    /// and carry the source's tags along with rescaled positions.
    #[test]
    fn test_identity() -> Result<()> {
        let input = ramp_input();
        let taps = vec![Complex::new(1.0, 0.0)];
        for deci in 1..=(3 * input.len()) {
            let (mut src, src_out) = VectorSource::builder(input.clone())
                .repeat(Repeat::finite(2))
                .build()?;
            assert!(matches![src.work()?, BlockRet::Again]);
            assert!(matches![src.work()?, BlockRet::EOF]);

            eprintln!("Testing identity with decimation {deci}");
            let (mut b, os) = FirFilter::builder(&taps).deci(deci).build(src_out);
            // With more decimation than input there is nothing to produce.
            if deci <= 2 * input.len() {
                assert!(matches![b.work()?, BlockRet::Again]);
            }
            assert!(matches![b.work()?, BlockRet::WaitForStream(_, _)]);
            let (res, tags) = os.read_buf()?;
            let max = 2 * input.len() / deci;
            if !res.is_empty() {
                assert_eq!(
                    &tags,
                    &[
                        Tag::new(0, "VectorSource::start", TagValue::Bool(true)),
                        Tag::new(0, "VectorSource::repeat", TagValue::U64(0)),
                        Tag::new(0, "VectorSource::first", TagValue::Bool(true)),
                        Tag::new(6 / deci, "VectorSource::start", TagValue::Bool(true)),
                        Tag::new(6 / deci, "VectorSource::repeat", TagValue::U64(1)),
                    ]
                );
            }
            assert_almost_equal_complex(
                res.slice(),
                &input
                    .iter()
                    .chain(input.iter())
                    .copied()
                    .step_by(deci)
                    .take(max)
                    .collect::<Vec<_>>(),
            );
        }
        Ok(())
    }

    /// A single tap of -1 must negate the input, at every decimation.
    #[test]
    fn test_invert() -> Result<()> {
        let input = ramp_input();
        let taps = vec![Complex::new(-1.0, 0.0)];
        for deci in 1..=(input.len() + 1) {
            let (mut src, src_out) = VectorSource::new(input.clone());
            src.work()?;

            eprintln!("Testing invert with decimation {deci}");
            let (mut b, os) = FirFilter::builder(&taps).deci(deci).build(src_out);
            if deci <= input.len() {
                assert!(matches![b.work()?, BlockRet::Again]);
            }
            assert!(matches![b.work()?, BlockRet::WaitForStream(_, _)]);
            let (res, _) = os.read_buf()?;
            let max = input.len() / deci;
            assert_almost_equal_complex(
                res.slice(),
                &input
                    .iter()
                    .copied()
                    .step_by(deci)
                    .take(max)
                    .map(|v| -v)
                    .collect::<Vec<_>>(),
            );
        }
        Ok(())
    }

    /// Two taps of 0.5 must average neighbouring samples, at every
    /// decimation.
    #[test]
    fn moving_avg() -> Result<()> {
        let input = ramp_input();
        let taps = vec![Complex::new(0.5, 0.0), Complex::new(0.5, 0.0)];
        for deci in 1..=(input.len() + 1) {
            let (mut src, src_out) = VectorSource::new(input.clone());
            src.work()?;

            eprintln!("Testing moving average with decimation {deci}");
            let (mut b, os) = FirFilter::builder(&taps).deci(deci).build(src_out);
            // Two taps need one extra input sample, hence `<` and not `<=`.
            if deci < input.len() {
                assert!(matches![b.work()?, BlockRet::Again]);
            }
            assert!(matches![b.work()?, BlockRet::WaitForStream(_, _)]);
            let (res, _) = os.read_buf()?;
            let max = (input.len() - 1) / deci;
            assert_almost_equal_complex(
                res.slice(),
                &[
                    Complex::new(1.5, 0.0),
                    Complex::new(2.5, 0.1),
                    Complex::new(3.55, 0.1),
                    Complex::new(4.55, 0.0),
                    Complex::new(5.5, 0.1),
                ]
                .into_iter()
                .step_by(deci)
                .take(max)
                .collect::<Vec<_>>(),
            );
        }
        Ok(())
    }

    /// Exercise `Fir` directly with genuinely complex taps, with and without
    /// decimation.
    #[test]
    fn test_complex() {
        let input = ramp_input();
        let taps = vec![
            Complex::new(0.1, 0.0),
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 0.2),
        ];
        let filter = Fir::new(&taps);
        assert_almost_equal_complex(
            &filter.filter_n(&input, 1),
            &[
                Complex::new(2.3, 0.22),
                Complex::new(3.41, 0.6),
                Complex::new(4.56, 0.6),
                Complex::new(5.6, 0.84),
            ],
        );
        assert_almost_equal_complex(
            &filter.filter_n(&input, 2),
            &[Complex::new(2.3, 0.22), Complex::new(4.56, 0.6)],
        );
    }

    /// Pin the low pass taps generated for a known set of parameters.
    #[test]
    fn test_filter_generator() {
        let taps = low_pass_complex(10000.0, 1000.0, 1000.0, &WindowType::Hamming);
        assert_eq!(taps.len(), 25);
        assert_almost_equal_complex(
            &taps,
            &[
                Complex::new(0.002010403, 0.0),
                Complex::new(0.0016210203, 0.0),
                Complex::new(7.851862e-10, 0.0),
                Complex::new(-0.0044467063, 0.0),
                Complex::new(-0.011685465, 0.0),
                Complex::new(-0.018134259, 0.0),
                Complex::new(-0.016773716, 0.0),
                Complex::new(-3.6538055e-9, 0.0),
                Complex::new(0.0358771, 0.0),
                Complex::new(0.08697697, 0.0),
                Complex::new(0.14148787, 0.0),
                Complex::new(0.18345332, 0.0),
                Complex::new(0.19922684, 0.0),
                Complex::new(0.1834533, 0.0),
                Complex::new(0.14148785, 0.0),
                Complex::new(0.08697697, 0.0),
                Complex::new(0.035877097, 0.0),
                Complex::new(-3.6538053e-9, 0.0),
                Complex::new(-0.016773716, 0.0),
                Complex::new(-0.018134257, 0.0),
                Complex::new(-0.011685458, 0.0),
                Complex::new(-0.0044467044, 0.0),
                Complex::new(7.851859e-10, 0.0),
                Complex::new(0.0016210207, 0.0),
                Complex::new(0.002010403, 0.0),
            ],
        );
    }
}