// pitch_detector/core/fft_space.rs
use std::borrow::Borrow;

use num_traits::Zero;
use rustfft::num_complex::Complex;

5mod utils {
6 use rustfft::num_complex::Complex;
7 pub struct FreqDomainIter<'a> {
8 pub(super) complex_iter: std::slice::Iter<'a, Complex<f64>>,
9 pub(super) square_rooted: bool,
10 }
11
12 impl Iterator for FreqDomainIter<'_> {
13 type Item = (f64, f64);
14
15 fn next(&mut self) -> Option<Self::Item> {
16 match self.complex_iter.next() {
17 Some(complex) => {
18 let value = complex.norm_sqr();
19 let phase = complex.arg();
20 if self.square_rooted {
21 Some((value.sqrt(), phase))
22 } else {
23 Some((value, phase))
24 }
25 }
26 None => None,
27 }
28 }
29
30 fn size_hint(&self) -> (usize, Option<usize>) {
31 self.complex_iter.size_hint()
32 }
33 }
34}
35
36#[derive(Debug, Clone)]
37pub struct FftSpace {
38 signal_len: usize,
39 space: Vec<Complex<f64>>,
40 scratch: Vec<Complex<f64>>,
41}
42
43impl FftSpace {
44 pub fn new(size: usize) -> Self {
45 let mut padded_size = (2usize).pow(10);
46 padded_size = loop {
47 if padded_size < size {
48 padded_size *= 2;
49 } else {
50 break padded_size;
51 }
52 };
53 FftSpace {
54 signal_len: size,
55 space: vec![Complex::zero(); padded_size],
56 scratch: vec![Complex::zero(); padded_size],
57 }
58 }
59
60 pub fn map<F: Fn(&Complex<f64>) -> Complex<f64>>(&mut self, map_fn: F) {
61 self.space.iter_mut().for_each(|f| {
62 *f = map_fn(f);
63 });
64 }
65
66 pub fn signal_len(&self) -> usize {
67 self.signal_len
68 }
69
70 pub fn padded_len(&self) -> usize {
71 self.space.len()
72 }
73
74 pub fn space(&self) -> &[Complex<f64>] {
75 &self.space
76 }
77
78 pub fn signal(&self) -> Box<dyn Iterator<Item = f64> + '_> {
79 Box::new(self.space[..self.signal_len].iter().map(|f| f.re))
80 }
81
82 pub fn workspace(&mut self) -> (&mut [Complex<f64>], &mut [Complex<f64>]) {
83 (&mut self.space, &mut self.scratch)
84 }
85
86 pub fn init_with_signal<I: IntoIterator>(&mut self, signal: I)
87 where
88 <I as IntoIterator>::Item: std::borrow::Borrow<f64>,
89 {
90 let signal_iter = signal.into_iter();
91 let signal_len = signal_iter
92 .size_hint()
93 .1
94 .expect("Signal length is not known");
95 assert!(signal_len <= self.space.len());
96 signal_iter
97 .zip(self.space.iter_mut())
98 .for_each(|(sample, fft)| {
99 fft.re = *sample.borrow();
100 fft.im = 0.0;
101 });
102 self.space[signal_len..]
103 .iter_mut()
104 .for_each(|o| *o = Complex::zero())
105 }
106
107 pub fn freq_domain(&self, square_rooted: bool) -> utils::FreqDomainIter {
108 utils::FreqDomainIter {
109 complex_iter: self.space.iter(),
110 square_rooted,
111 }
112 }
113}