1use std::collections::HashMap;
4use std::sync::{Arc, LazyLock, Mutex, RwLock};
5
6use num_complex::Complex;
7use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
8use rustfft::{Fft, FftPlanner};
9
10use ferray_core::error::{FerrayError, FerrayResult};
11use ferray_core::{Array, Ix1};
12
13use crate::norm::{FftDirection, FftNorm};
14
/// Lookup key for the complex-FFT plan caches.
///
/// Two keys compare equal (and hash identically) only when both the transform
/// length and the direction flag match, so forward and inverse plans of the
/// same length occupy separate cache slots.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Transform length the plan is built for.
    size: usize,
    /// `true` selects the inverse plan, `false` the forward plan.
    inverse: bool,
}
21
/// Reader/writer-locked table of complex FFT plans, keyed by [`CacheKey`]
/// (transform length plus direction) and generic over the sample type `T`.
type ComplexPlanMap<T> = RwLock<HashMap<CacheKey, Arc<dyn Fft<T>>>>;

/// Process-wide cache of `f64` complex FFT plans; starts out empty and is
/// created on first access.
static F64_PLANS: LazyLock<ComplexPlanMap<f64>> = LazyLock::new(|| RwLock::new(HashMap::new()));
/// Shared `f64` planner, guarded by a mutex; `get_cached_plan_f64` locks it to
/// build a plan when the cache has no entry for the requested key.
static F64_PLANNER: LazyLock<Mutex<FftPlanner<f64>>> =
    LazyLock::new(|| Mutex::new(FftPlanner::new()));
/// Process-wide cache of `f32` complex FFT plans; starts out empty and is
/// created on first access.
static F32_PLANS: LazyLock<ComplexPlanMap<f32>> = LazyLock::new(|| RwLock::new(HashMap::new()));
/// Shared `f32` planner, guarded by a mutex; `get_cached_plan_f32` locks it to
/// build a plan when the cache has no entry for the requested key.
static F32_PLANNER: LazyLock<Mutex<FftPlanner<f32>>> =
    LazyLock::new(|| Mutex::new(FftPlanner::new()));
49
50pub(crate) fn get_cached_plan_f64(size: usize, inverse: bool) -> Arc<dyn Fft<f64>> {
56 let key = CacheKey { size, inverse };
57 {
59 let plans = F64_PLANS.read().expect("f64 FFT plan cache poisoned");
60 if let Some(plan) = plans.get(&key) {
61 return plan.clone();
62 }
63 }
64 let plan: Arc<dyn Fft<f64>> = {
68 let mut planner = F64_PLANNER.lock().expect("f64 FFT planner poisoned");
69 if inverse {
70 planner.plan_fft_inverse(size)
71 } else {
72 planner.plan_fft_forward(size)
73 }
74 };
75 let mut plans = F64_PLANS.write().expect("f64 FFT plan cache poisoned");
76 plans.entry(key).or_insert(plan).clone()
77}
78
79pub(crate) fn get_cached_plan_f32(size: usize, inverse: bool) -> Arc<dyn Fft<f32>> {
81 let key = CacheKey { size, inverse };
82 {
83 let plans = F32_PLANS.read().expect("f32 FFT plan cache poisoned");
84 if let Some(plan) = plans.get(&key) {
85 return plan.clone();
86 }
87 }
88 let plan: Arc<dyn Fft<f32>> = {
89 let mut planner = F32_PLANNER.lock().expect("f32 FFT planner poisoned");
90 if inverse {
91 planner.plan_fft_inverse(size)
92 } else {
93 planner.plan_fft_forward(size)
94 }
95 };
96 let mut plans = F32_PLANS.write().expect("f32 FFT plan cache poisoned");
97 plans.entry(key).or_insert(plan).clone()
98}
99
/// Returns the cached `f64` complex FFT plan for `size` points.
///
/// Thin alias that forwards both arguments unchanged to
/// [`get_cached_plan_f64`]; `inverse` selects the inverse plan when `true`.
pub(crate) fn get_cached_plan(size: usize, inverse: bool) -> Arc<dyn Fft<f64>> {
    get_cached_plan_f64(size, inverse)
}
106
/// Reader/writer-locked table of real-to-complex plans, keyed by the `size`
/// passed to the planner.
type RealForwardMap<T> = RwLock<HashMap<usize, Arc<dyn RealToComplex<T>>>>;
/// Reader/writer-locked table of complex-to-real plans, keyed by the `size`
/// passed to the planner.
type RealInverseMap<T> = RwLock<HashMap<usize, Arc<dyn ComplexToReal<T>>>>;

/// Process-wide cache of `f64` real-to-complex plans; starts out empty.
static REAL_F64_FORWARD: LazyLock<RealForwardMap<f64>> =
    LazyLock::new(|| RwLock::new(HashMap::new()));
/// Process-wide cache of `f64` complex-to-real plans; starts out empty.
static REAL_F64_INVERSE: LazyLock<RealInverseMap<f64>> =
    LazyLock::new(|| RwLock::new(HashMap::new()));
/// Shared `f64` real-FFT planner; one mutex-guarded instance serves both the
/// forward and the inverse `f64` caches.
static REAL_F64_PLANNER: LazyLock<Mutex<RealFftPlanner<f64>>> =
    LazyLock::new(|| Mutex::new(RealFftPlanner::new()));
/// Process-wide cache of `f32` real-to-complex plans; starts out empty.
static REAL_F32_FORWARD: LazyLock<RealForwardMap<f32>> =
    LazyLock::new(|| RwLock::new(HashMap::new()));
/// Process-wide cache of `f32` complex-to-real plans; starts out empty.
static REAL_F32_INVERSE: LazyLock<RealInverseMap<f32>> =
    LazyLock::new(|| RwLock::new(HashMap::new()));
/// Shared `f32` real-FFT planner; one mutex-guarded instance serves both the
/// forward and the inverse `f32` caches.
static REAL_F32_PLANNER: LazyLock<Mutex<RealFftPlanner<f32>>> =
    LazyLock::new(|| Mutex::new(RealFftPlanner::new()));
131
132pub(crate) fn get_cached_real_forward_f64(size: usize) -> Arc<dyn RealToComplex<f64>> {
137 {
138 let plans = REAL_F64_FORWARD
139 .read()
140 .expect("f64 real-forward plan cache poisoned");
141 if let Some(plan) = plans.get(&size) {
142 return plan.clone();
143 }
144 }
145 let plan: Arc<dyn RealToComplex<f64>> = {
146 let mut planner = REAL_F64_PLANNER
147 .lock()
148 .expect("f64 real FFT planner poisoned");
149 planner.plan_fft_forward(size)
150 };
151 let mut plans = REAL_F64_FORWARD
152 .write()
153 .expect("f64 real-forward plan cache poisoned");
154 plans.entry(size).or_insert(plan).clone()
155}
156
157pub(crate) fn get_cached_real_inverse_f64(size: usize) -> Arc<dyn ComplexToReal<f64>> {
162 {
163 let plans = REAL_F64_INVERSE
164 .read()
165 .expect("f64 real-inverse plan cache poisoned");
166 if let Some(plan) = plans.get(&size) {
167 return plan.clone();
168 }
169 }
170 let plan: Arc<dyn ComplexToReal<f64>> = {
171 let mut planner = REAL_F64_PLANNER
172 .lock()
173 .expect("f64 real FFT planner poisoned");
174 planner.plan_fft_inverse(size)
175 };
176 let mut plans = REAL_F64_INVERSE
177 .write()
178 .expect("f64 real-inverse plan cache poisoned");
179 plans.entry(size).or_insert(plan).clone()
180}
181
182pub(crate) fn get_cached_real_forward_f32(size: usize) -> Arc<dyn RealToComplex<f32>> {
184 {
185 let plans = REAL_F32_FORWARD
186 .read()
187 .expect("f32 real-forward plan cache poisoned");
188 if let Some(plan) = plans.get(&size) {
189 return plan.clone();
190 }
191 }
192 let plan: Arc<dyn RealToComplex<f32>> = {
193 let mut planner = REAL_F32_PLANNER
194 .lock()
195 .expect("f32 real FFT planner poisoned");
196 planner.plan_fft_forward(size)
197 };
198 let mut plans = REAL_F32_FORWARD
199 .write()
200 .expect("f32 real-forward plan cache poisoned");
201 plans.entry(size).or_insert(plan).clone()
202}
203
204pub(crate) fn get_cached_real_inverse_f32(size: usize) -> Arc<dyn ComplexToReal<f32>> {
206 {
207 let plans = REAL_F32_INVERSE
208 .read()
209 .expect("f32 real-inverse plan cache poisoned");
210 if let Some(plan) = plans.get(&size) {
211 return plan.clone();
212 }
213 }
214 let plan: Arc<dyn ComplexToReal<f32>> = {
215 let mut planner = REAL_F32_PLANNER
216 .lock()
217 .expect("f32 real FFT planner poisoned");
218 planner.plan_fft_inverse(size)
219 };
220 let mut plans = REAL_F32_INVERSE
221 .write()
222 .expect("f32 real-inverse plan cache poisoned");
223 plans.entry(size).or_insert(plan).clone()
224}
225
// NOTE(review): the lint is silenced because nothing in this file calls the
// alias; presumably it is kept for callers elsewhere in the crate — confirm.
#[allow(dead_code)]
/// Returns the cached `f64` real-to-complex plan for `size`.
///
/// Thin alias that forwards to [`get_cached_real_forward_f64`].
pub(crate) fn get_cached_real_forward(size: usize) -> Arc<dyn RealToComplex<f64>> {
    get_cached_real_forward_f64(size)
}
231
// NOTE(review): the lint is silenced because nothing in this file calls the
// alias; presumably it is kept for callers elsewhere in the crate — confirm.
#[allow(dead_code)]
/// Returns the cached `f64` complex-to-real plan for `size`.
///
/// Thin alias that forwards to [`get_cached_real_inverse_f64`].
pub(crate) fn get_cached_real_inverse(size: usize) -> Arc<dyn ComplexToReal<f64>> {
    get_cached_real_inverse_f64(size)
}
237
/// A reusable 1-D complex `f64` FFT plan for one fixed transform length.
///
/// Holds shared handles to both the forward and the inverse plan for that
/// length, so a single `FftPlan` can run transforms in either direction.
pub struct FftPlan {
    /// Forward plan, obtained from `get_cached_plan(size, false)`.
    forward: Arc<dyn Fft<f64>>,
    /// Inverse plan, obtained from `get_cached_plan(size, true)`.
    inverse: Arc<dyn Fft<f64>>,
    /// Transform length; inputs must contain exactly this many elements.
    size: usize,
}
263
264impl FftPlan {
268 pub fn new(size: usize) -> FerrayResult<Self> {
277 if size == 0 {
278 return Err(FerrayError::invalid_value("FFT plan size must be > 0"));
279 }
280 let forward = get_cached_plan(size, false);
281 let inverse = get_cached_plan(size, true);
282 Ok(Self {
283 forward,
284 inverse,
285 size,
286 })
287 }
288
289 pub fn size(&self) -> usize {
291 self.size
292 }
293
294 pub fn execute(
303 &self,
304 signal: &Array<Complex<f64>, Ix1>,
305 ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
306 self.execute_with_norm(signal, FftNorm::Backward)
307 }
308
309 pub fn execute_with_norm(
315 &self,
316 signal: &Array<Complex<f64>, Ix1>,
317 norm: FftNorm,
318 ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
319 if signal.size() != self.size {
320 return Err(FerrayError::shape_mismatch(format!(
321 "signal length {} does not match plan size {}",
322 signal.size(),
323 self.size,
324 )));
325 }
326 let mut buffer: Vec<Complex<f64>> = signal.iter().copied().collect();
327 let mut scratch = vec![Complex::new(0.0, 0.0); self.forward.get_inplace_scratch_len()];
328 self.forward.process_with_scratch(&mut buffer, &mut scratch);
329
330 let scale = norm.scale_factor(self.size, FftDirection::Forward);
331 if (scale - 1.0).abs() > f64::EPSILON {
332 for c in &mut buffer {
333 *c *= scale;
334 }
335 }
336
337 Array::from_vec(Ix1::new([self.size]), buffer)
338 }
339
340 pub fn execute_inverse(
348 &self,
349 spectrum: &Array<Complex<f64>, Ix1>,
350 ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
351 self.execute_inverse_with_norm(spectrum, FftNorm::Backward)
352 }
353
354 pub fn execute_inverse_with_norm(
360 &self,
361 spectrum: &Array<Complex<f64>, Ix1>,
362 norm: FftNorm,
363 ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
364 if spectrum.size() != self.size {
365 return Err(FerrayError::shape_mismatch(format!(
366 "spectrum length {} does not match plan size {}",
367 spectrum.size(),
368 self.size,
369 )));
370 }
371 let mut buffer: Vec<Complex<f64>> = spectrum.iter().copied().collect();
372 let mut scratch = vec![Complex::new(0.0, 0.0); self.inverse.get_inplace_scratch_len()];
373 self.inverse.process_with_scratch(&mut buffer, &mut scratch);
374
375 let scale = norm.scale_factor(self.size, FftDirection::Inverse);
376 if (scale - 1.0).abs() > f64::EPSILON {
377 for c in &mut buffer {
378 *c *= scale;
379 }
380 }
381
382 Array::from_vec(Ix1::new([self.size]), buffer)
383 }
384}
385
386impl std::fmt::Debug for FftPlan {
387 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388 f.debug_struct("FftPlan").field("size", &self.size).finish()
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    // A positive size yields a plan that reports that same size.
    #[test]
    fn plan_new_valid() {
        let reported = FftPlan::new(8).unwrap().size();
        assert_eq!(reported, 8);
    }

    // Size zero is rejected at construction time.
    #[test]
    fn plan_new_zero_errors() {
        let outcome = FftPlan::new(0);
        assert!(outcome.is_err());
    }

    // Forward followed by inverse (both with default normalization) must
    // reproduce the input to within 1e-12 per component.
    #[test]
    fn plan_execute_roundtrip() {
        let samples: Vec<Complex<f64>> = [1.0, 2.0, 3.0, 4.0]
            .iter()
            .map(|&re| Complex::new(re, 0.0))
            .collect();
        let signal =
            Array::<Complex<f64>, Ix1>::from_vec(Ix1::new([4]), samples.clone()).unwrap();
        let plan = FftPlan::new(4).unwrap();

        let spectrum = plan.execute(&signal).unwrap();
        let recovered = plan.execute_inverse(&spectrum).unwrap();

        for (expected, actual) in samples.iter().zip(recovered.iter()) {
            let re_err = (expected.re - actual.re).abs();
            let im_err = (expected.im - actual.im).abs();
            assert!(re_err < 1e-12);
            assert!(im_err < 1e-12);
        }
    }

    // Feeding a 4-point signal to an 8-point plan must fail instead of transforming.
    #[test]
    fn plan_size_mismatch() {
        let zeros = vec![Complex::new(0.0, 0.0); 4];
        let short_signal =
            Array::<Complex<f64>, Ix1>::from_vec(Ix1::new([4]), zeros).unwrap();
        let plan = FftPlan::new(8).unwrap();
        assert!(plan.execute(&short_signal).is_err());
    }

    // Two lookups with the same key must hand back the very same allocation.
    #[test]
    fn cached_plan_reuse() {
        let first = get_cached_plan(16, false);
        let second = get_cached_plan(16, false);
        assert!(Arc::ptr_eq(&first, &second));
    }

    // Compile-time check: `FftPlan` can be shared and moved across threads.
    #[test]
    fn plan_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<FftPlan>();
    }
}