Skip to main content

neco_stft/
fft_backend.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::sync::Arc;
4
5use neco_complex::Complex;
6
7use crate::dsp_float::DspFloat;
8use crate::internal_fft;
9
/// Error returned when a buffer handed to an FFT plan has the wrong length.
///
/// Both variants carry `(expected, got)`: the length the plan requires and the
/// length of the slice that was actually supplied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FftError {
    /// The input slice length did not match the plan: `(expected, got)`.
    InputBuffer(usize, usize),
    /// The output slice length did not match the plan: `(expected, got)`.
    OutputBuffer(usize, usize),
}

impl fmt::Display for FftError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InputBuffer(expected, got) => {
                write!(
                    f,
                    "wrong input buffer length: expected {expected}, got {got}"
                )
            }
            Self::OutputBuffer(expected, got) => {
                write!(
                    f,
                    "wrong output buffer length: expected {expected}, got {got}"
                )
            }
        }
    }
}

/// `FftError` wraps no underlying cause, so the default `source()` (`None`) applies.
/// Implementing the trait lets callers convert it into `Box<dyn std::error::Error>`
/// (and propagate it with `?`) alongside other error types.
impl std::error::Error for FftError {}
34
35pub trait RealToComplex<T>: Send + Sync {
36    fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Result<(), FftError>;
37    fn make_input_vec(&self) -> Vec<T>;
38    fn make_output_vec(&self) -> Vec<Complex<T>>;
39    fn len(&self) -> usize;
40    fn is_empty(&self) -> bool {
41        self.len() == 0
42    }
43}
44
45pub trait ComplexToReal<T>: Send + Sync {
46    fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Result<(), FftError>;
47    fn make_input_vec(&self) -> Vec<Complex<T>>;
48    fn make_output_vec(&self) -> Vec<T>;
49    fn len(&self) -> usize;
50    fn is_empty(&self) -> bool {
51        self.len() == 0
52    }
53}
54
55pub trait FftPlanner<T> {
56    fn plan_fft_forward(&mut self, len: usize) -> Arc<dyn RealToComplex<T>>;
57    fn plan_fft_inverse(&mut self, len: usize) -> Arc<dyn ComplexToReal<T>>;
58}
59
/// Forward (real-to-complex) plan for one fixed transform length.
///
/// Only the length is stored. `T` appears solely through `PhantomData`, so the
/// plan holds no sample data and does no precomputation.
struct InternalR2C<T> {
    /// Number of real input samples the plan accepts.
    len: usize,
    _marker: std::marker::PhantomData<T>,
}

impl<T> InternalR2C<T> {
    /// Creates a plan for `len` real samples. This just records the length.
    fn new(len: usize) -> Self {
        let _marker = std::marker::PhantomData;
        Self { _marker, len }
    }
}
73
74impl<T> RealToComplex<T> for InternalR2C<T>
75where
76    T: DspFloat,
77{
78    fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Result<(), FftError> {
79        if input.len() != self.len {
80            return Err(FftError::InputBuffer(self.len, input.len()));
81        }
82        let expected = self.len / 2 + 1;
83        if output.len() != expected {
84            return Err(FftError::OutputBuffer(expected, output.len()));
85        }
86        let spectrum = internal_fft::real_fft_forward(input);
87        output.copy_from_slice(&spectrum);
88        Ok(())
89    }
90
91    fn make_input_vec(&self) -> Vec<T> {
92        vec![T::zero(); self.len]
93    }
94
95    fn make_output_vec(&self) -> Vec<Complex<T>> {
96        vec![Complex::new(T::zero(), T::zero()); self.len / 2 + 1]
97    }
98
99    fn len(&self) -> usize {
100        self.len
101    }
102}
103
/// Inverse (complex-to-real) plan for one fixed transform length.
///
/// Only the length is stored. `T` appears solely through `PhantomData`, so the
/// plan holds no sample data and does no precomputation.
struct InternalC2R<T> {
    /// Number of real output samples the plan produces.
    len: usize,
    _marker: std::marker::PhantomData<T>,
}

impl<T> InternalC2R<T> {
    /// Creates a plan producing `len` real samples. This just records the length.
    fn new(len: usize) -> Self {
        let _marker = std::marker::PhantomData;
        Self { _marker, len }
    }
}
117
118impl<T> ComplexToReal<T> for InternalC2R<T>
119where
120    T: DspFloat,
121{
122    fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Result<(), FftError> {
123        let expected_in = self.len / 2 + 1;
124        if input.len() != expected_in {
125            return Err(FftError::InputBuffer(expected_in, input.len()));
126        }
127        if output.len() != self.len {
128            return Err(FftError::OutputBuffer(self.len, output.len()));
129        }
130        internal_fft::real_fft_inverse(input, output);
131        Ok(())
132    }
133
134    fn make_input_vec(&self) -> Vec<Complex<T>> {
135        vec![Complex::new(T::zero(), T::zero()); self.len / 2 + 1]
136    }
137
138    fn make_output_vec(&self) -> Vec<T> {
139        vec![T::zero(); self.len]
140    }
141
142    fn len(&self) -> usize {
143        self.len
144    }
145}
146
147pub struct RustFftPlannerF32 {
148    r2c_cache: HashMap<usize, Arc<dyn RealToComplex<f32>>>,
149    c2r_cache: HashMap<usize, Arc<dyn ComplexToReal<f32>>>,
150}
151
152impl RustFftPlannerF32 {
153    pub fn new() -> Self {
154        Self {
155            r2c_cache: HashMap::new(),
156            c2r_cache: HashMap::new(),
157        }
158    }
159}
160
161impl Default for RustFftPlannerF32 {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
167impl FftPlanner<f32> for RustFftPlannerF32 {
168    fn plan_fft_forward(&mut self, len: usize) -> Arc<dyn RealToComplex<f32>> {
169        self.r2c_cache
170            .entry(len)
171            .or_insert_with(|| Arc::new(InternalR2C::<f32>::new(len)))
172            .clone()
173    }
174
175    fn plan_fft_inverse(&mut self, len: usize) -> Arc<dyn ComplexToReal<f32>> {
176        self.c2r_cache
177            .entry(len)
178            .or_insert_with(|| Arc::new(InternalC2R::<f32>::new(len)))
179            .clone()
180    }
181}
182
183pub struct RustFftPlannerF64 {
184    r2c_cache: HashMap<usize, Arc<dyn RealToComplex<f64>>>,
185    c2r_cache: HashMap<usize, Arc<dyn ComplexToReal<f64>>>,
186}
187
188impl RustFftPlannerF64 {
189    pub fn new() -> Self {
190        Self {
191            r2c_cache: HashMap::new(),
192            c2r_cache: HashMap::new(),
193        }
194    }
195}
196
197impl Default for RustFftPlannerF64 {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203impl FftPlanner<f64> for RustFftPlannerF64 {
204    fn plan_fft_forward(&mut self, len: usize) -> Arc<dyn RealToComplex<f64>> {
205        self.r2c_cache
206            .entry(len)
207            .or_insert_with(|| Arc::new(InternalR2C::<f64>::new(len)))
208            .clone()
209    }
210
211    fn plan_fft_inverse(&mut self, len: usize) -> Arc<dyn ComplexToReal<f64>> {
212        self.c2r_cache
213            .entry(len)
214            .or_insert_with(|| Arc::new(InternalC2R::<f64>::new(len)))
215            .clone()
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `signal` forward then inverse through a fresh `f64` planner.
    ///
    /// Returns the largest absolute difference between `signal` and the inverse
    /// output rescaled by `1 / signal.len()`.
    fn roundtrip_max_err_f64(signal: &[f64]) -> f64 {
        let n = signal.len();
        let mut planner = RustFftPlannerF64::new();
        let forward = planner.plan_fft_forward(n);
        let inverse = planner.plan_fft_inverse(n);

        // `process` takes its input mutably, so transform a copy of the signal.
        let mut scratch = signal.to_vec();
        let mut bins = forward.make_output_vec();
        forward.process(&mut scratch, &mut bins).unwrap();

        let mut restored = inverse.make_output_vec();
        inverse.process(&mut bins, &mut restored).unwrap();

        let inv_n = 1.0 / n as f64;
        restored
            .iter()
            .zip(signal)
            .map(|(&y, &x)| (y * inv_n - x).abs())
            .fold(0.0, f64::max)
    }

    /// `f32` counterpart of `roundtrip_max_err_f64`.
    fn roundtrip_max_err_f32(signal: &[f32]) -> f32 {
        let n = signal.len();
        let mut planner = RustFftPlannerF32::new();
        let forward = planner.plan_fft_forward(n);
        let inverse = planner.plan_fft_inverse(n);

        let mut scratch = signal.to_vec();
        let mut bins = forward.make_output_vec();
        forward.process(&mut scratch, &mut bins).unwrap();

        let mut restored = inverse.make_output_vec();
        inverse.process(&mut bins, &mut restored).unwrap();

        let inv_n = 1.0f32 / n as f32;
        restored
            .iter()
            .zip(signal)
            .map(|(&y, &x)| (y * inv_n - x).abs())
            .fold(0.0, f32::max)
    }

    #[test]
    fn planner_roundtrip_f64_power_of_two() {
        use std::f64::consts::PI;

        let signal: Vec<f64> = (0..1024)
            .map(|i| (2.0 * PI * 440.0 * i as f64 / 48000.0).sin())
            .collect();

        let max_err = roundtrip_max_err_f64(&signal);
        assert!(max_err < 1e-10, "roundtrip error: {max_err:.2e}");
    }

    #[test]
    fn planner_roundtrip_f64_non_power_of_two() {
        use std::f64::consts::PI;

        let len = 1001;
        let signal: Vec<f64> = (0..len)
            .map(|i| {
                let t = i as f64 / len as f64;
                (2.0 * PI * 7.0 * t).sin() + 0.4 * (2.0 * PI * 19.0 * t).cos()
            })
            .collect();

        let max_err = roundtrip_max_err_f64(&signal);
        assert!(max_err < 1e-9, "roundtrip error: {max_err:.2e}");
    }

    #[test]
    fn planner_roundtrip_f32_non_power_of_two() {
        use std::f32::consts::PI;

        let len = 777;
        let signal: Vec<f32> = (0..len)
            .map(|i| {
                let t = i as f32 / len as f32;
                (2.0f32 * PI * 5.0 * t).sin() + 0.25 * (2.0f32 * PI * 11.0 * t).cos()
            })
            .collect();

        let max_err = roundtrip_max_err_f32(&signal);
        assert!(max_err < 5e-4, "roundtrip error: {max_err:.2e}");
    }
}