1use std::collections::HashMap;
2use std::fmt;
3use std::sync::Arc;
4
5use neco_complex::Complex;
6
7use crate::dsp_float::DspFloat;
8use crate::internal_fft;
9
/// Error returned when a buffer handed to an FFT plan has the wrong length.
///
/// Both variants carry `(expected, got)`: the length the plan requires and
/// the length that was actually supplied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FftError {
    /// The input buffer length did not match the plan: `(expected, got)`.
    InputBuffer(usize, usize),
    /// The output buffer length did not match the plan: `(expected, got)`.
    OutputBuffer(usize, usize),
}

impl fmt::Display for FftError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InputBuffer(expected, got) => {
                write!(
                    f,
                    "wrong input buffer length: expected {expected}, got {got}"
                )
            }
            Self::OutputBuffer(expected, got) => {
                write!(
                    f,
                    "wrong output buffer length: expected {expected}, got {got}"
                )
            }
        }
    }
}

// `FftError` wraps no underlying cause, so the default `source()` (None) is
// correct. Implementing the trait lets callers box it as `dyn Error` or
// propagate it with `?` into generic error types.
impl std::error::Error for FftError {}
34
35pub trait RealToComplex<T>: Send + Sync {
36 fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Result<(), FftError>;
37 fn make_input_vec(&self) -> Vec<T>;
38 fn make_output_vec(&self) -> Vec<Complex<T>>;
39 fn len(&self) -> usize;
40 fn is_empty(&self) -> bool {
41 self.len() == 0
42 }
43}
44
45pub trait ComplexToReal<T>: Send + Sync {
46 fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Result<(), FftError>;
47 fn make_input_vec(&self) -> Vec<Complex<T>>;
48 fn make_output_vec(&self) -> Vec<T>;
49 fn len(&self) -> usize;
50 fn is_empty(&self) -> bool {
51 self.len() == 0
52 }
53}
54
55pub trait FftPlanner<T> {
56 fn plan_fft_forward(&mut self, len: usize) -> Arc<dyn RealToComplex<T>>;
57 fn plan_fft_inverse(&mut self, len: usize) -> Arc<dyn ComplexToReal<T>>;
58}
59
/// Forward plan whose `process` calls `internal_fft::real_fft_forward`.
///
/// It stores only the transform length. The `PhantomData` ties the plan to
/// its sample type `T` without holding any `T`.
struct InternalR2C<T> {
    _marker: std::marker::PhantomData<T>,
    len: usize,
}

impl<T> InternalR2C<T> {
    /// Builds a forward plan for `len` real input samples.
    fn new(len: usize) -> Self {
        let marker = std::marker::PhantomData;
        Self {
            _marker: marker,
            len,
        }
    }
}
73
74impl<T> RealToComplex<T> for InternalR2C<T>
75where
76 T: DspFloat,
77{
78 fn process(&self, input: &mut [T], output: &mut [Complex<T>]) -> Result<(), FftError> {
79 if input.len() != self.len {
80 return Err(FftError::InputBuffer(self.len, input.len()));
81 }
82 let expected = self.len / 2 + 1;
83 if output.len() != expected {
84 return Err(FftError::OutputBuffer(expected, output.len()));
85 }
86 let spectrum = internal_fft::real_fft_forward(input);
87 output.copy_from_slice(&spectrum);
88 Ok(())
89 }
90
91 fn make_input_vec(&self) -> Vec<T> {
92 vec![T::zero(); self.len]
93 }
94
95 fn make_output_vec(&self) -> Vec<Complex<T>> {
96 vec![Complex::new(T::zero(), T::zero()); self.len / 2 + 1]
97 }
98
99 fn len(&self) -> usize {
100 self.len
101 }
102}
103
/// Inverse plan whose `process` calls `internal_fft::real_fft_inverse`.
///
/// It stores only the transform length. The `PhantomData` ties the plan to
/// its sample type `T` without holding any `T`.
struct InternalC2R<T> {
    _marker: std::marker::PhantomData<T>,
    len: usize,
}

impl<T> InternalC2R<T> {
    /// Builds an inverse plan that produces `len` real output samples.
    fn new(len: usize) -> Self {
        let marker = std::marker::PhantomData;
        Self {
            _marker: marker,
            len,
        }
    }
}
117
118impl<T> ComplexToReal<T> for InternalC2R<T>
119where
120 T: DspFloat,
121{
122 fn process(&self, input: &mut [Complex<T>], output: &mut [T]) -> Result<(), FftError> {
123 let expected_in = self.len / 2 + 1;
124 if input.len() != expected_in {
125 return Err(FftError::InputBuffer(expected_in, input.len()));
126 }
127 if output.len() != self.len {
128 return Err(FftError::OutputBuffer(self.len, output.len()));
129 }
130 internal_fft::real_fft_inverse(input, output);
131 Ok(())
132 }
133
134 fn make_input_vec(&self) -> Vec<Complex<T>> {
135 vec![Complex::new(T::zero(), T::zero()); self.len / 2 + 1]
136 }
137
138 fn make_output_vec(&self) -> Vec<T> {
139 vec![T::zero(); self.len]
140 }
141
142 fn len(&self) -> usize {
143 self.len
144 }
145}
146
147pub struct RustFftPlannerF32 {
148 r2c_cache: HashMap<usize, Arc<dyn RealToComplex<f32>>>,
149 c2r_cache: HashMap<usize, Arc<dyn ComplexToReal<f32>>>,
150}
151
152impl RustFftPlannerF32 {
153 pub fn new() -> Self {
154 Self {
155 r2c_cache: HashMap::new(),
156 c2r_cache: HashMap::new(),
157 }
158 }
159}
160
161impl Default for RustFftPlannerF32 {
162 fn default() -> Self {
163 Self::new()
164 }
165}
166
167impl FftPlanner<f32> for RustFftPlannerF32 {
168 fn plan_fft_forward(&mut self, len: usize) -> Arc<dyn RealToComplex<f32>> {
169 self.r2c_cache
170 .entry(len)
171 .or_insert_with(|| Arc::new(InternalR2C::<f32>::new(len)))
172 .clone()
173 }
174
175 fn plan_fft_inverse(&mut self, len: usize) -> Arc<dyn ComplexToReal<f32>> {
176 self.c2r_cache
177 .entry(len)
178 .or_insert_with(|| Arc::new(InternalC2R::<f32>::new(len)))
179 .clone()
180 }
181}
182
183pub struct RustFftPlannerF64 {
184 r2c_cache: HashMap<usize, Arc<dyn RealToComplex<f64>>>,
185 c2r_cache: HashMap<usize, Arc<dyn ComplexToReal<f64>>>,
186}
187
188impl RustFftPlannerF64 {
189 pub fn new() -> Self {
190 Self {
191 r2c_cache: HashMap::new(),
192 c2r_cache: HashMap::new(),
193 }
194 }
195}
196
197impl Default for RustFftPlannerF64 {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
203impl FftPlanner<f64> for RustFftPlannerF64 {
204 fn plan_fft_forward(&mut self, len: usize) -> Arc<dyn RealToComplex<f64>> {
205 self.r2c_cache
206 .entry(len)
207 .or_insert_with(|| Arc::new(InternalR2C::<f64>::new(len)))
208 .clone()
209 }
210
211 fn plan_fft_inverse(&mut self, len: usize) -> Arc<dyn ComplexToReal<f64>> {
212 self.c2r_cache
213 .entry(len)
214 .or_insert_with(|| Arc::new(InternalC2R::<f64>::new(len)))
215 .clone()
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn planner_roundtrip_f64_power_of_two() {
        let mut planner = RustFftPlannerF64::new();
        let fft_fwd = planner.plan_fft_forward(1024);
        let fft_inv = planner.plan_fft_inverse(1024);

        let input: Vec<f64> = (0..1024)
            .map(|i| (2.0 * std::f64::consts::PI * 440.0 * i as f64 / 48000.0).sin())
            .collect();

        let mut buf = input.clone();
        let mut spectrum = fft_fwd.make_output_vec();
        fft_fwd.process(&mut buf, &mut spectrum).unwrap();

        let mut output = fft_inv.make_output_vec();
        fft_inv.process(&mut spectrum, &mut output).unwrap();

        let scale = 1.0 / 1024.0;
        let max_err = output
            .iter()
            .zip(input.iter())
            .map(|(&o, &i)| (o * scale - i).abs())
            .fold(0.0, f64::max);
        assert!(max_err < 1e-10, "roundtrip error: {max_err:.2e}");
    }

    #[test]
    fn planner_roundtrip_f64_non_power_of_two() {
        let len = 1001;
        let mut planner = RustFftPlannerF64::new();
        let fft_fwd = planner.plan_fft_forward(len);
        let fft_inv = planner.plan_fft_inverse(len);

        let input: Vec<f64> = (0..len)
            .map(|i| {
                let t = i as f64 / len as f64;
                (2.0 * std::f64::consts::PI * 7.0 * t).sin()
                    + 0.4 * (2.0 * std::f64::consts::PI * 19.0 * t).cos()
            })
            .collect();

        let mut buf = input.clone();
        let mut spectrum = fft_fwd.make_output_vec();
        fft_fwd.process(&mut buf, &mut spectrum).unwrap();

        let mut output = fft_inv.make_output_vec();
        fft_inv.process(&mut spectrum, &mut output).unwrap();

        let scale = 1.0 / len as f64;
        let max_err = output
            .iter()
            .zip(input.iter())
            .map(|(&o, &i)| (o * scale - i).abs())
            .fold(0.0, f64::max);
        assert!(max_err < 1e-9, "roundtrip error: {max_err:.2e}");
    }

    #[test]
    fn planner_roundtrip_f32_non_power_of_two() {
        let len = 777;
        let mut planner = RustFftPlannerF32::new();
        let fft_fwd = planner.plan_fft_forward(len);
        let fft_inv = planner.plan_fft_inverse(len);

        let input: Vec<f32> = (0..len)
            .map(|i| {
                let t = i as f32 / len as f32;
                (2.0f32 * std::f32::consts::PI * 5.0 * t).sin()
                    + 0.25 * (2.0f32 * std::f32::consts::PI * 11.0 * t).cos()
            })
            .collect();

        let mut buf = input.clone();
        let mut spectrum = fft_fwd.make_output_vec();
        fft_fwd.process(&mut buf, &mut spectrum).unwrap();

        let mut output = fft_inv.make_output_vec();
        fft_inv.process(&mut spectrum, &mut output).unwrap();

        let scale = 1.0f32 / len as f32;
        let max_err = output
            .iter()
            .zip(input.iter())
            .map(|(&o, &i)| (o * scale - i).abs())
            .fold(0.0, f32::max);
        assert!(max_err < 5e-4, "roundtrip error: {max_err:.2e}");
    }

    /// Forward plans must reject mis-sized buffers with `(expected, got)` payloads,
    /// checking the input length before the output length.
    #[test]
    fn forward_rejects_wrong_buffer_lengths() {
        let mut planner = RustFftPlannerF64::new();
        let fft = planner.plan_fft_forward(16);

        let mut short_in = vec![0.0f64; 15];
        let mut out = fft.make_output_vec();
        assert_eq!(
            fft.process(&mut short_in, &mut out),
            Err(FftError::InputBuffer(16, 15))
        );

        let mut input = fft.make_input_vec();
        let mut long_out = vec![Complex::new(0.0f64, 0.0f64); 10];
        assert_eq!(
            fft.process(&mut input, &mut long_out),
            Err(FftError::OutputBuffer(9, 10))
        );

        // Both wrong: the input error wins.
        assert_eq!(
            fft.process(&mut short_in, &mut long_out),
            Err(FftError::InputBuffer(16, 15))
        );
    }

    /// Inverse plans must reject mis-sized buffers with `(expected, got)` payloads,
    /// checking the spectrum length before the output length.
    #[test]
    fn inverse_rejects_wrong_buffer_lengths() {
        let mut planner = RustFftPlannerF64::new();
        let ifft = planner.plan_fft_inverse(16);

        let mut short_spec = vec![Complex::new(0.0f64, 0.0f64); 8];
        let mut out = ifft.make_output_vec();
        assert_eq!(
            ifft.process(&mut short_spec, &mut out),
            Err(FftError::InputBuffer(9, 8))
        );

        let mut spec = ifft.make_input_vec();
        let mut long_out = vec![0.0f64; 17];
        assert_eq!(
            ifft.process(&mut spec, &mut long_out),
            Err(FftError::OutputBuffer(16, 17))
        );
    }

    /// `make_*_vec` helpers must produce zeroed buffers that `process` accepts.
    #[test]
    fn make_vec_sizes_match_plan() {
        let mut planner = RustFftPlannerF32::new();
        let fwd = planner.plan_fft_forward(10);
        let inv = planner.plan_fft_inverse(10);

        assert_eq!(fwd.len(), 10);
        assert!(!fwd.is_empty());
        assert_eq!(fwd.make_input_vec().len(), 10);
        assert_eq!(fwd.make_output_vec().len(), 6);
        assert_eq!(inv.make_input_vec().len(), 6);
        assert_eq!(inv.make_output_vec().len(), 10);
        assert!(fwd.make_input_vec().iter().all(|&x| x == 0.0));
        assert!(inv.make_output_vec().iter().all(|&x| x == 0.0));
    }

    /// Requesting the same length twice must return the same cached plan.
    #[test]
    fn planner_reuses_cached_plans() {
        let mut planner = RustFftPlannerF64::new();
        let a = planner.plan_fft_forward(64);
        let b = planner.plan_fft_forward(64);
        let c = planner.plan_fft_forward(32);

        // Compare data addresses only; vtable pointers of `dyn` objects are not
        // guaranteed to be unique.
        let fwd_addr =
            |p: &Arc<dyn RealToComplex<f64>>| -> *const () { Arc::as_ptr(p) as *const () };
        assert_eq!(fwd_addr(&a), fwd_addr(&b));
        assert_ne!(fwd_addr(&a), fwd_addr(&c));
        assert_eq!(c.len(), 32);

        let inv_a = planner.plan_fft_inverse(64);
        let inv_b = planner.plan_fft_inverse(64);
        let inv_addr =
            |p: &Arc<dyn ComplexToReal<f64>>| -> *const () { Arc::as_ptr(p) as *const () };
        assert_eq!(inv_addr(&inv_a), inv_addr(&inv_b));
    }
}