1use crate::block::{Block, BlockRet};
10use crate::stream::{ReadStream, WriteStream};
11use crate::window::{Window, WindowType};
12use crate::{Complex, Float, Result, Sample};
13
/// Plain FIR filter kernel: a dot product of a window of input samples with
/// a fixed set of coefficients.
pub struct Fir<T> {
    // Coefficients stored in reverse order of how they were given to
    // `Fir::new`, so a window can be multiplied element-wise front to back.
    taps: Vec<T>,
}
18
/// Dot product of two equal-length `f32` slices using 256-bit AVX registers.
///
/// Whole groups of eight elements are multiplied and accumulated in one
/// `__m256` register, which is then reduced horizontally; the remaining
/// (fewer than eight) elements are added with a scalar loop.
///
/// Only compiled when AVX, SSE3 and SSE are enabled at compile time, so the
/// intrinsics used here are guaranteed to be available on the target.
///
/// # Panics
///
/// Panics if `vec1` and `vec2` differ in length.
#[cfg(all(
    target_feature = "avx",
    target_feature = "sse3",
    target_feature = "sse"
))]
fn sum_product_avx(vec1: &[f32], vec2: &[f32]) -> f32 {
    // The same intrinsics live under different module paths on 32- and
    // 64-bit x86; importing only `x86_64` broke 32-bit AVX builds.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    assert_eq!(vec1.len(), vec2.len());
    // Largest multiple of eight not exceeding the slice length.
    let len = vec1.len() - vec1.len() % 8;

    // SAFETY: the cfg on this function guarantees AVX/SSE3/SSE support.
    // Each load reads eight floats starting at offset `i`, and
    // `i + 8 <= len <= vec1.len() == vec2.len()`, so every read is in bounds;
    // `loadu` has no alignment requirement.
    let partial = unsafe {
        let mut sum = _mm256_setzero_ps();
        for i in (0..len).step_by(8) {
            let a = _mm256_loadu_ps(vec1.as_ptr().add(i));
            let b = _mm256_loadu_ps(vec2.as_ptr().add(i));
            sum = _mm256_add_ps(sum, _mm256_mul_ps(a, b));
        }

        // Horizontal reduction of the eight lanes. After the three `hadd`s,
        // lane 0 holds the total; the other lanes are don't-care filler.
        let low = _mm256_extractf128_ps(sum, 0);
        let high = _mm256_extractf128_ps(sum, 1);
        let m128 = _mm_hadd_ps(low, high);
        let m128 = _mm_hadd_ps(m128, low);
        let m128 = _mm_hadd_ps(m128, low);
        _mm_cvtss_f32(m128)
    };

    // Scalar tail for the elements that did not fill a full AVX register.
    vec1[len..]
        .iter()
        .zip(vec2[len..].iter())
        .fold(partial, |acc, (&f, &x)| acc + x * f)
}
72
73impl Fir<Float> {
74 #[must_use]
77 pub fn filter_float(&self, input: &[Float]) -> Float {
78 #[cfg(all(
80 target_feature = "avx",
81 target_feature = "sse3",
82 target_feature = "sse"
83 ))]
84 return sum_product_avx(&self.taps, input);
85 #[cfg(feature = "simd")]
87 #[allow(unreachable_code)]
88 {
89 use std::simd::num::SimdFloat;
90 type Batch = std::simd::f32x8;
91
92 let batch_n = 8;
93 let partial = input
95 .chunks_exact(batch_n)
96 .zip(self.taps.chunks_exact(batch_n))
97 .map(|(a, b)| Batch::from_slice(a) * Batch::from_slice(b))
98 .fold(Batch::splat(0.0), |acc, x| acc + x)
99 .reduce_sum();
100 let skip = self.taps.len() - self.taps.len() % batch_n;
102 return input[skip..]
103 .iter()
104 .zip(self.taps[skip..].iter())
105 .fold(partial, |acc, (&f, &x)| acc + x * f);
106 }
107 #[allow(unreachable_code)]
108 self.filter(input)
109 }
110}
111
112impl<T> Fir<T>
113where
114 T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
115{
116 #[must_use]
118 pub fn new(taps: &[T]) -> Self {
119 Self {
120 taps: taps.iter().copied().rev().collect(),
121 }
122 }
123 #[must_use]
126 pub fn filter(&self, input: &[T]) -> T {
127 assert!(
128 input.len() >= self.taps.len(),
129 "input {} < taps {}",
130 input.len(),
131 self.taps.len()
132 );
133 input
134 .iter()
135 .zip(self.taps.iter())
136 .fold(T::default(), |acc, (&f, &x)| acc + x * f)
137 }
138
139 #[must_use]
141 pub fn filter_n(&self, input: &[T], deci: usize) -> Vec<T> {
142 let n = input.len() - self.taps.len();
143 (0..=n)
144 .step_by(deci)
145 .map(|i| self.filter(&input[i..]))
146 .collect()
147 }
148
149 pub fn filter_n_inplace(&self, input: &[T], deci: usize, out: &mut [T]) {
151 out.iter_mut()
152 .enumerate()
153 .for_each(|(i, o)| *o = self.filter(&input[(i * deci)..]));
154 }
155}
156
/// Builder for [`FirFilter`], obtained from [`FirFilter::builder`].
pub struct FirFilterBuilder<T> {
    // Coefficients in conventional order; reversal happens in `Fir::new`.
    taps: Vec<T>,
    // Decimation factor handed to the built block.
    deci: usize,
}
164
165impl<T> FirFilterBuilder<T>
166where
167 T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
168{
169 #[must_use]
173 pub fn deci(mut self, deci: usize) -> Self {
174 self.deci = deci;
175 self
176 }
177
178 #[must_use]
180 pub fn build(self, src: ReadStream<T>) -> (FirFilter<T>, ReadStream<T>) {
181 let (mut block, stream) = FirFilter::new(src, &self.taps);
182 block.deci = self.deci;
183 (block, stream)
184 }
185}
186
187#[derive(rustradio_macros::Block)]
189#[rustradio(crate)]
190pub struct FirFilter<T: Sample> {
191 fir: Fir<T>,
192 ntaps: usize,
193 deci: usize,
194 #[rustradio(in)]
195 src: ReadStream<T>,
196 #[rustradio(out)]
197 dst: WriteStream<T>,
198}
199
200impl<T> FirFilter<T>
201where
202 T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
203{
204 pub fn builder(taps: &[T]) -> FirFilterBuilder<T> {
206 FirFilterBuilder {
207 taps: taps.to_vec(),
208 deci: 1,
209 }
210 }
211 pub fn new(src: ReadStream<T>, taps: &[T]) -> (Self, ReadStream<T>) {
213 let (dst, dr) = crate::stream::new_stream();
214 (
215 Self {
216 src,
217 dst,
218 ntaps: taps.len(),
219 deci: 1,
220 fir: Fir::new(taps),
221 },
222 dr,
223 )
224 }
225}
226
227impl<T> Block for FirFilter<T>
228where
229 T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
230{
231 fn work(&mut self) -> Result<BlockRet<'_>> {
232 let (input, mut tags) = self.src.read_buf()?;
233
234 let n = {
236 let absolute_minimum = self.ntaps + self.deci - 1;
238 if input.len() < absolute_minimum {
239 return Ok(BlockRet::WaitForStream(&self.src, absolute_minimum));
240 }
241 self.deci * ((input.len() - self.ntaps + 1) / self.deci)
242 };
243 assert_ne!(n, 0);
244
245 let need = n + self.ntaps - 1;
247 assert!(input.len() >= need, "need {need}, have {}", input.len());
248
249 let mut out = self.dst.write_buf()?;
251 let need_out = 1;
252 if out.len() < need_out {
253 return Ok(BlockRet::WaitForStream(&self.dst, need_out));
254 }
255
256 let n = std::cmp::min(n, out.len() * self.deci);
258
259 assert_eq!(n % self.deci, 0);
261 assert_ne!(n, 0, "input: {} out: {}", input.len(), out.len());
262
263 let out_n = n / self.deci;
265 self.fir
266 .filter_n_inplace(&input.slice()[..need], self.deci, &mut out.slice()[..out_n]);
267
268 assert!(out_n <= out.len());
270
271 input.consume(n);
272 if self.deci == 1 {
273 out.produce(out_n, &tags);
274 } else {
275 for t in &mut tags {
276 t.set_pos(t.pos() / self.deci);
277 }
278 out.produce(out_n, &tags);
279 }
280 Ok(BlockRet::Again)
284 }
285}
286
287#[must_use]
291pub fn multiband(bands: &[(Float, Float)], taps: usize, window: &Window) -> Option<Vec<Complex>> {
292 use rustfft::FftPlanner;
293
294 if taps != window.0.len() {
295 return None;
296 }
297
298 let mut ideal = vec![Complex::new(0.0, 0.0); taps];
299 let scale = (taps as Float) / 2.0;
300 for (low, high) in bands {
301 let a = (low * scale).floor() as usize;
302 let b = (high * scale).ceil() as usize;
303 for n in a..b {
304 ideal[n] = Complex::new(1.0, 0.0);
305 ideal[taps - n - 1] = Complex::new(1.0, 0.0);
306 }
307 }
308 let fft_size = taps;
309 let mut planner = FftPlanner::new();
310 let ifft = planner.plan_fft_inverse(fft_size);
311 ifft.process(&mut ideal);
312 ideal.rotate_right(taps / 2);
313 let scale = (fft_size as Float).sqrt();
314 Some(
315 ideal
316 .into_iter()
317 .enumerate()
318 .map(|(n, v)| v * window.0[n] / Complex::new(scale, 0.0))
319 .collect(),
320 )
321}
322
323#[must_use]
325pub fn low_pass_complex(
326 samp_rate: Float,
327 cutoff: Float,
328 twidth: Float,
329 window_type: &WindowType,
330) -> Vec<Complex> {
331 low_pass(samp_rate, cutoff, twidth, window_type)
332 .into_iter()
333 .map(|t| Complex::new(t, 0.0))
334 .collect()
335}
336
337fn compute_ntaps(samp_rate: Float, twidth: Float, window_type: &WindowType) -> usize {
338 let a = window_type.max_attenuation();
339 let t = (a * samp_rate / (22.0 * twidth)) as usize;
340 if (t & 1) == 0 { t + 1 } else { t }
341}
342
343#[must_use]
348pub fn low_pass(
349 samp_rate: Float,
350 cutoff: Float,
351 twidth: Float,
352 window_type: &WindowType,
353) -> Vec<Float> {
354 let pi = std::f64::consts::PI as Float;
355 let ntaps = compute_ntaps(samp_rate, twidth, window_type);
356 let window = window_type.make_window(ntaps);
357 let m = (ntaps - 1) / 2;
358 let fwt0 = 2.0 * pi * cutoff / samp_rate;
359 let taps: Vec<_> = window
360 .0
361 .iter()
362 .enumerate()
363 .map(|(nm, win)| {
364 let n = nm as i64 - m as i64;
365 let nf = n as Float;
366 if n == 0 {
367 fwt0 / pi * win
368 } else {
369 ((nf * fwt0).sin() / (nf * pi)) * win
370 }
371 })
372 .collect();
373 let gain = {
374 let gain: Float = 1.0;
375 let mut fmax = taps[m];
376 for n in 1..=m {
377 fmax += 2.0 * taps[n + m];
378 }
379 gain / fmax
380 };
381 taps.into_iter().map(|t| t * gain).collect()
382}
383
384#[must_use]
386pub fn hilbert(window: &Window) -> Vec<Float> {
387 let ntaps = window.0.len();
388 let mid = (ntaps - 1) / 2;
389 let mut gain = 0.0;
390 let mut taps = vec![0.0; ntaps];
391 for i in 1..=mid {
392 if i & 1 == 1 {
393 let x = 1.0 / (i as Float);
394 taps[mid + i] = x * window.0[mid + i];
395 taps[mid - i] = -x * window.0[mid - i];
396 gain = taps[mid + i] - gain;
397 } else {
398 taps[mid + i] = 0.0;
399 taps[mid - i] = 0.0;
400 }
401 }
402 let gain = 1.0 / (2.0 * gain.abs());
403 taps.iter().map(|e| gain * *e).collect()
404}
405
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::Repeat;
    use crate::blocks::VectorSource;
    use crate::stream::{Tag, TagValue};
    use crate::tests::assert_almost_equal_complex;

    /// Six-sample complex signal shared by the tests below.
    fn sample_input() -> Vec<Complex> {
        vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.2),
            Complex::new(4.1, 0.0),
            Complex::new(5.0, 0.0),
            Complex::new(6.0, 0.2),
        ]
    }

    #[test]
    fn test_identity() -> Result<()> {
        let input = sample_input();
        let taps = vec![Complex::new(1.0, 0.0)];
        for deci in 1..=(3 * input.len()) {
            let (mut src, src_out) = VectorSource::builder(input.clone())
                .repeat(Repeat::finite(2))
                .build()?;
            assert!(matches!(src.work()?, BlockRet::Again));
            assert!(matches!(src.work()?, BlockRet::EOF));

            eprintln!("Testing identity with decimation {deci}");
            let (mut block, out_stream) = FirFilter::builder(&taps).deci(deci).build(src_out);
            if deci <= 2 * input.len() {
                assert!(matches!(block.work()?, BlockRet::Again));
            }
            assert!(matches!(block.work()?, BlockRet::WaitForStream(_, _)));
            let (res, tags) = out_stream.read_buf()?;
            let max = 2 * input.len() / deci;
            if !res.is_empty() {
                assert_eq!(
                    &tags,
                    &[
                        Tag::new(0, "VectorSource::start", TagValue::Bool(true)),
                        Tag::new(0, "VectorSource::repeat", TagValue::U64(0)),
                        Tag::new(0, "VectorSource::first", TagValue::Bool(true)),
                        Tag::new(6 / deci, "VectorSource::start", TagValue::Bool(true)),
                        Tag::new(6 / deci, "VectorSource::repeat", TagValue::U64(1)),
                    ]
                );
            }
            let expected: Vec<_> = input
                .iter()
                .chain(input.iter())
                .copied()
                .step_by(deci)
                .take(max)
                .collect();
            assert_almost_equal_complex(res.slice(), &expected);
        }
        Ok(())
    }

    #[test]
    fn test_invert() -> Result<()> {
        let input = sample_input();
        let taps = vec![Complex::new(-1.0, 0.0)];
        for deci in 1..=(input.len() + 1) {
            let (mut src, src_out) = VectorSource::new(input.clone());
            src.work()?;

            eprintln!("Testing identity with decimation {deci}");
            let (mut block, out_stream) = FirFilter::builder(&taps).deci(deci).build(src_out);
            if deci <= input.len() {
                assert!(matches!(block.work()?, BlockRet::Again));
            }
            assert!(matches!(block.work()?, BlockRet::WaitForStream(_, _)));
            let (res, _) = out_stream.read_buf()?;
            let max = input.len() / deci;
            let expected: Vec<_> = input
                .iter()
                .copied()
                .step_by(deci)
                .take(max)
                .map(|v| -v)
                .collect();
            assert_almost_equal_complex(res.slice(), &expected);
        }
        Ok(())
    }

    #[test]
    fn moving_avg() -> Result<()> {
        let input = sample_input();
        let taps = vec![Complex::new(0.5, 0.0), Complex::new(0.5, 0.0)];
        for deci in 1..=(input.len() + 1) {
            let (mut src, src_out) = VectorSource::new(input.clone());
            src.work()?;

            eprintln!("Testing identity with decimation {deci}");
            let (mut block, out_stream) = FirFilter::builder(&taps).deci(deci).build(src_out);
            if deci < input.len() {
                assert!(matches!(block.work()?, BlockRet::Again));
            }
            assert!(matches!(block.work()?, BlockRet::WaitForStream(_, _)));
            let (res, _) = out_stream.read_buf()?;
            let max = (input.len() - 1) / deci;
            let expected: Vec<_> = [
                Complex::new(1.5, 0.0),
                Complex::new(2.5, 0.1),
                Complex::new(3.55, 0.1),
                Complex::new(4.55, 0.0),
                Complex::new(5.5, 0.1),
            ]
            .into_iter()
            .step_by(deci)
            .take(max)
            .collect();
            assert_almost_equal_complex(res.slice(), &expected);
        }
        Ok(())
    }

    #[test]
    fn test_complex() {
        let input = sample_input();
        let taps = vec![
            Complex::new(0.1, 0.0),
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 0.2),
        ];
        let filter = Fir::new(&taps);
        assert_almost_equal_complex(
            &filter.filter_n(&input, 1),
            &[
                Complex::new(2.3, 0.22),
                Complex::new(3.41, 0.6),
                Complex::new(4.56, 0.6),
                Complex::new(5.6, 0.84),
            ],
        );
        assert_almost_equal_complex(
            &filter.filter_n(&input, 2),
            &[Complex::new(2.3, 0.22), Complex::new(4.56, 0.6)],
        );
    }

    #[test]
    fn test_filter_generator() {
        let taps = low_pass_complex(10000.0, 1000.0, 1000.0, &WindowType::Hamming);
        assert_eq!(taps.len(), 25);
        let expected = [
            Complex::new(0.002010403, 0.0),
            Complex::new(0.0016210203, 0.0),
            Complex::new(7.851862e-10, 0.0),
            Complex::new(-0.0044467063, 0.0),
            Complex::new(-0.011685465, 0.0),
            Complex::new(-0.018134259, 0.0),
            Complex::new(-0.016773716, 0.0),
            Complex::new(-3.6538055e-9, 0.0),
            Complex::new(0.0358771, 0.0),
            Complex::new(0.08697697, 0.0),
            Complex::new(0.14148787, 0.0),
            Complex::new(0.18345332, 0.0),
            Complex::new(0.19922684, 0.0),
            Complex::new(0.1834533, 0.0),
            Complex::new(0.14148785, 0.0),
            Complex::new(0.08697697, 0.0),
            Complex::new(0.035877097, 0.0),
            Complex::new(-3.6538053e-9, 0.0),
            Complex::new(-0.016773716, 0.0),
            Complex::new(-0.018134257, 0.0),
            Complex::new(-0.011685458, 0.0),
            Complex::new(-0.0044467044, 0.0),
            Complex::new(7.851859e-10, 0.0),
            Complex::new(0.0016210207, 0.0),
            Complex::new(0.002010403, 0.0),
        ];
        assert_almost_equal_complex(&taps, &expected);
    }
}