Skip to main content

ferray_fft/
plan.rs

1// ferray-fft: FftPlan type and global plan cache (REQ-12, REQ-13)
2
3use std::collections::HashMap;
4use std::sync::{Arc, LazyLock, Mutex, RwLock};
5
6use num_complex::Complex;
7use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
8use rustfft::{Fft, FftPlanner};
9
10use ferray_core::error::{FerrayError, FerrayResult};
11use ferray_core::{Array, Ix1};
12
13use crate::norm::{FftDirection, FftNorm};
14
/// Lookup key for the complex-FFT plan caches: one entry per distinct
/// (transform length, direction) pair, where `inverse == true` selects the
/// inverse transform.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct CacheKey {
    size: usize,
    inverse: bool,
}
21
22// ---------------------------------------------------------------------------
23// Complex FFT plan caches (f32 and f64)
24//
25// The cache splits into two locks (#431):
26//
27//   * `RwLock<HashMap<key, Arc<plan>>>` — held in *read* mode for the
28//     overwhelmingly common cache-hit lookup. Many concurrent FFTs of
29//     already-cached sizes can read in parallel without serializing.
30//   * `Mutex<FftPlanner<T>>` — held only on the (rare) cache miss path,
31//     because `FftPlanner::plan_fft_forward` requires `&mut self`.
32//
// The previous design wrapped both fields in a single `Mutex`, so even
// the lookup path serialized every concurrent caller; under heavy
// multi-threaded load that turned what should be a read-mostly path
// into a point of lock contention.
37// ---------------------------------------------------------------------------
38
39// Aliases keep clippy's `type_complexity` lint quiet on the `LazyLock`
40// statics below.
41type ComplexPlanMap<T> = RwLock<HashMap<CacheKey, Arc<dyn Fft<T>>>>;
42
43static F64_PLANS: LazyLock<ComplexPlanMap<f64>> = LazyLock::new(|| RwLock::new(HashMap::new()));
44static F64_PLANNER: LazyLock<Mutex<FftPlanner<f64>>> =
45    LazyLock::new(|| Mutex::new(FftPlanner::new()));
46static F32_PLANS: LazyLock<ComplexPlanMap<f32>> = LazyLock::new(|| RwLock::new(HashMap::new()));
47static F32_PLANNER: LazyLock<Mutex<FftPlanner<f32>>> =
48    LazyLock::new(|| Mutex::new(FftPlanner::new()));
49
50/// Obtain a cached f64 FFT plan for the given size and direction.
51///
52/// This is the primary internal entry point used by all f64 FFT functions.
53/// Plans are cached globally so repeated transforms of the same size
54/// reuse the same plan.
55pub(crate) fn get_cached_plan_f64(size: usize, inverse: bool) -> Arc<dyn Fft<f64>> {
56    let key = CacheKey { size, inverse };
57    // Fast path: read lock the plan map and return an existing plan.
58    {
59        let plans = F64_PLANS.read().expect("f64 FFT plan cache poisoned");
60        if let Some(plan) = plans.get(&key) {
61            return plan.clone();
62        }
63    }
64    // Cache miss: build a new plan with the planner Mutex, then insert
65    // it under the write lock. Two concurrent misses for the same key
66    // are handled by the second one finding the entry already populated.
67    let plan: Arc<dyn Fft<f64>> = {
68        let mut planner = F64_PLANNER.lock().expect("f64 FFT planner poisoned");
69        if inverse {
70            planner.plan_fft_inverse(size)
71        } else {
72            planner.plan_fft_forward(size)
73        }
74    };
75    let mut plans = F64_PLANS.write().expect("f64 FFT plan cache poisoned");
76    plans.entry(key).or_insert(plan).clone()
77}
78
79/// Obtain a cached f32 FFT plan for the given size and direction.
80pub(crate) fn get_cached_plan_f32(size: usize, inverse: bool) -> Arc<dyn Fft<f32>> {
81    let key = CacheKey { size, inverse };
82    {
83        let plans = F32_PLANS.read().expect("f32 FFT plan cache poisoned");
84        if let Some(plan) = plans.get(&key) {
85            return plan.clone();
86        }
87    }
88    let plan: Arc<dyn Fft<f32>> = {
89        let mut planner = F32_PLANNER.lock().expect("f32 FFT planner poisoned");
90        if inverse {
91            planner.plan_fft_inverse(size)
92        } else {
93            planner.plan_fft_forward(size)
94        }
95    };
96    let mut plans = F32_PLANS.write().expect("f32 FFT plan cache poisoned");
97    plans.entry(key).or_insert(plan).clone()
98}
99
100/// Legacy alias for f64 plan lookup — preserved for tests and any call
101/// sites that haven't been generified yet. New code should go through
102/// [`crate::float::FftFloat::cached_plan`].
103pub(crate) fn get_cached_plan(size: usize, inverse: bool) -> Arc<dyn Fft<f64>> {
104    get_cached_plan_f64(size, inverse)
105}
106
107// ---------------------------------------------------------------------------
108// Real-FFT plan caches (issues #432, #426, #431)
109//
// Same RwLock-plus-Mutex split as the complex caches above. realfft has
// its own planner type (`RealFftPlanner`) rather than rustfft's
// `FftPlanner`, and its real-to-complex (`RealToComplex<T>`) and
// complex-to-real (`ComplexToReal<T>`) plans are distinct trait objects,
// so each direction gets its own plan map.
114// ---------------------------------------------------------------------------
115
116type RealForwardMap<T> = RwLock<HashMap<usize, Arc<dyn RealToComplex<T>>>>;
117type RealInverseMap<T> = RwLock<HashMap<usize, Arc<dyn ComplexToReal<T>>>>;
118
119static REAL_F64_FORWARD: LazyLock<RealForwardMap<f64>> =
120    LazyLock::new(|| RwLock::new(HashMap::new()));
121static REAL_F64_INVERSE: LazyLock<RealInverseMap<f64>> =
122    LazyLock::new(|| RwLock::new(HashMap::new()));
123static REAL_F64_PLANNER: LazyLock<Mutex<RealFftPlanner<f64>>> =
124    LazyLock::new(|| Mutex::new(RealFftPlanner::new()));
125static REAL_F32_FORWARD: LazyLock<RealForwardMap<f32>> =
126    LazyLock::new(|| RwLock::new(HashMap::new()));
127static REAL_F32_INVERSE: LazyLock<RealInverseMap<f32>> =
128    LazyLock::new(|| RwLock::new(HashMap::new()));
129static REAL_F32_PLANNER: LazyLock<Mutex<RealFftPlanner<f32>>> =
130    LazyLock::new(|| Mutex::new(RealFftPlanner::new()));
131
132/// Obtain a cached f64 real-to-complex FFT plan for the given size.
133///
134/// Returns a plan that consumes a real input of length `size` and
135/// produces `size/2 + 1` complex values.
136pub(crate) fn get_cached_real_forward_f64(size: usize) -> Arc<dyn RealToComplex<f64>> {
137    {
138        let plans = REAL_F64_FORWARD
139            .read()
140            .expect("f64 real-forward plan cache poisoned");
141        if let Some(plan) = plans.get(&size) {
142            return plan.clone();
143        }
144    }
145    let plan: Arc<dyn RealToComplex<f64>> = {
146        let mut planner = REAL_F64_PLANNER
147            .lock()
148            .expect("f64 real FFT planner poisoned");
149        planner.plan_fft_forward(size)
150    };
151    let mut plans = REAL_F64_FORWARD
152        .write()
153        .expect("f64 real-forward plan cache poisoned");
154    plans.entry(size).or_insert(plan).clone()
155}
156
157/// Obtain a cached f64 complex-to-real FFT plan for the given output size.
158///
159/// The argument `size` is the **real-output** length (not the complex
160/// input length, which is `size/2 + 1`).
161pub(crate) fn get_cached_real_inverse_f64(size: usize) -> Arc<dyn ComplexToReal<f64>> {
162    {
163        let plans = REAL_F64_INVERSE
164            .read()
165            .expect("f64 real-inverse plan cache poisoned");
166        if let Some(plan) = plans.get(&size) {
167            return plan.clone();
168        }
169    }
170    let plan: Arc<dyn ComplexToReal<f64>> = {
171        let mut planner = REAL_F64_PLANNER
172            .lock()
173            .expect("f64 real FFT planner poisoned");
174        planner.plan_fft_inverse(size)
175    };
176    let mut plans = REAL_F64_INVERSE
177        .write()
178        .expect("f64 real-inverse plan cache poisoned");
179    plans.entry(size).or_insert(plan).clone()
180}
181
182/// Obtain a cached f32 real-to-complex FFT plan for the given size.
183pub(crate) fn get_cached_real_forward_f32(size: usize) -> Arc<dyn RealToComplex<f32>> {
184    {
185        let plans = REAL_F32_FORWARD
186            .read()
187            .expect("f32 real-forward plan cache poisoned");
188        if let Some(plan) = plans.get(&size) {
189            return plan.clone();
190        }
191    }
192    let plan: Arc<dyn RealToComplex<f32>> = {
193        let mut planner = REAL_F32_PLANNER
194            .lock()
195            .expect("f32 real FFT planner poisoned");
196        planner.plan_fft_forward(size)
197    };
198    let mut plans = REAL_F32_FORWARD
199        .write()
200        .expect("f32 real-forward plan cache poisoned");
201    plans.entry(size).or_insert(plan).clone()
202}
203
204/// Obtain a cached f32 complex-to-real FFT plan for the given output size.
205pub(crate) fn get_cached_real_inverse_f32(size: usize) -> Arc<dyn ComplexToReal<f32>> {
206    {
207        let plans = REAL_F32_INVERSE
208            .read()
209            .expect("f32 real-inverse plan cache poisoned");
210        if let Some(plan) = plans.get(&size) {
211            return plan.clone();
212        }
213    }
214    let plan: Arc<dyn ComplexToReal<f32>> = {
215        let mut planner = REAL_F32_PLANNER
216            .lock()
217            .expect("f32 real FFT planner poisoned");
218        planner.plan_fft_inverse(size)
219    };
220    let mut plans = REAL_F32_INVERSE
221        .write()
222        .expect("f32 real-inverse plan cache poisoned");
223    plans.entry(size).or_insert(plan).clone()
224}
225
226/// Legacy f64 alias for the real-forward cache.
227#[allow(dead_code)]
228pub(crate) fn get_cached_real_forward(size: usize) -> Arc<dyn RealToComplex<f64>> {
229    get_cached_real_forward_f64(size)
230}
231
232/// Legacy f64 alias for the real-inverse cache.
233#[allow(dead_code)]
234pub(crate) fn get_cached_real_inverse(size: usize) -> Arc<dyn ComplexToReal<f64>> {
235    get_cached_real_inverse_f64(size)
236}
237
238/// A reusable FFT plan for a specific transform size.
239///
240/// `FftPlan` caches the internal FFT algorithm for a given size,
241/// enabling efficient repeated transforms. Plans are `Send + Sync`
242/// and can be shared across threads.
243///
244/// # Example
245/// ```
246/// use ferray_fft::FftPlan;
247/// use ferray_core::{Array, Ix1};
248/// use num_complex::Complex;
249///
250/// let plan = FftPlan::new(8).unwrap();
251/// let signal = Array::<Complex<f64>, Ix1>::from_vec(
252///     Ix1::new([8]),
253///     vec![Complex::new(1.0, 0.0); 8],
254/// ).unwrap();
255/// let result = plan.execute(&signal).unwrap();
256/// assert_eq!(result.shape(), &[8]);
257/// ```
258pub struct FftPlan {
259    forward: Arc<dyn Fft<f64>>,
260    inverse: Arc<dyn Fft<f64>>,
261    size: usize,
262}
263
264// FftPlan is Send + Sync because Arc<dyn Fft<f64>> is Send + Sync
265// (rustfft plans are thread-safe). No manual unsafe impl needed.
266
267impl FftPlan {
268    /// Create a new FFT plan for the given transform size.
269    ///
270    /// The plan pre-computes the internal FFT algorithm so that
271    /// subsequent calls to [`execute`](Self::execute) and
272    /// [`execute_inverse`](Self::execute_inverse) are fast.
273    ///
274    /// # Errors
275    /// Returns `FerrayError::InvalidValue` if `size` is 0.
276    pub fn new(size: usize) -> FerrayResult<Self> {
277        if size == 0 {
278            return Err(FerrayError::invalid_value("FFT plan size must be > 0"));
279        }
280        let forward = get_cached_plan(size, false);
281        let inverse = get_cached_plan(size, true);
282        Ok(Self {
283            forward,
284            inverse,
285            size,
286        })
287    }
288
289    /// Return the transform size this plan was created for.
290    pub fn size(&self) -> usize {
291        self.size
292    }
293
294    /// Execute a forward FFT on the given signal.
295    ///
296    /// The input array must have exactly `self.size()` elements.
297    /// Uses `FftNorm::Backward` (no scaling on forward).
298    ///
299    /// # Errors
300    /// Returns `FerrayError::ShapeMismatch` if the input length
301    /// does not match the plan size.
302    pub fn execute(
303        &self,
304        signal: &Array<Complex<f64>, Ix1>,
305    ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
306        self.execute_with_norm(signal, FftNorm::Backward)
307    }
308
309    /// Execute a forward FFT with the specified normalization.
310    ///
311    /// # Errors
312    /// Returns `FerrayError::ShapeMismatch` if the input length
313    /// does not match the plan size.
314    pub fn execute_with_norm(
315        &self,
316        signal: &Array<Complex<f64>, Ix1>,
317        norm: FftNorm,
318    ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
319        if signal.size() != self.size {
320            return Err(FerrayError::shape_mismatch(format!(
321                "signal length {} does not match plan size {}",
322                signal.size(),
323                self.size,
324            )));
325        }
326        let mut buffer: Vec<Complex<f64>> = signal.iter().copied().collect();
327        let mut scratch = vec![Complex::new(0.0, 0.0); self.forward.get_inplace_scratch_len()];
328        self.forward.process_with_scratch(&mut buffer, &mut scratch);
329
330        let scale = norm.scale_factor(self.size, FftDirection::Forward);
331        if (scale - 1.0).abs() > f64::EPSILON {
332            for c in &mut buffer {
333                *c *= scale;
334            }
335        }
336
337        Array::from_vec(Ix1::new([self.size]), buffer)
338    }
339
340    /// Execute an inverse FFT on the given spectrum.
341    ///
342    /// Uses `FftNorm::Backward` (divides by `n` on inverse).
343    ///
344    /// # Errors
345    /// Returns `FerrayError::ShapeMismatch` if the input length
346    /// does not match the plan size.
347    pub fn execute_inverse(
348        &self,
349        spectrum: &Array<Complex<f64>, Ix1>,
350    ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
351        self.execute_inverse_with_norm(spectrum, FftNorm::Backward)
352    }
353
354    /// Execute an inverse FFT with the specified normalization.
355    ///
356    /// # Errors
357    /// Returns `FerrayError::ShapeMismatch` if the input length
358    /// does not match the plan size.
359    pub fn execute_inverse_with_norm(
360        &self,
361        spectrum: &Array<Complex<f64>, Ix1>,
362        norm: FftNorm,
363    ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
364        if spectrum.size() != self.size {
365            return Err(FerrayError::shape_mismatch(format!(
366                "spectrum length {} does not match plan size {}",
367                spectrum.size(),
368                self.size,
369            )));
370        }
371        let mut buffer: Vec<Complex<f64>> = spectrum.iter().copied().collect();
372        let mut scratch = vec![Complex::new(0.0, 0.0); self.inverse.get_inplace_scratch_len()];
373        self.inverse.process_with_scratch(&mut buffer, &mut scratch);
374
375        let scale = norm.scale_factor(self.size, FftDirection::Inverse);
376        if (scale - 1.0).abs() > f64::EPSILON {
377            for c in &mut buffer {
378                *c *= scale;
379            }
380        }
381
382        Array::from_vec(Ix1::new([self.size]), buffer)
383    }
384}
385
386impl std::fmt::Debug for FftPlan {
387    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388        f.debug_struct("FftPlan").field("size", &self.size).finish()
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap a vector of samples in a 1-D complex array of matching length.
    fn complex_array(data: Vec<Complex<f64>>) -> Array<Complex<f64>, Ix1> {
        let len = data.len();
        Array::<Complex<f64>, Ix1>::from_vec(Ix1::new([len]), data).unwrap()
    }

    #[test]
    fn plan_new_valid() {
        let plan = FftPlan::new(8).unwrap();
        assert_eq!(plan.size(), 8);
    }

    #[test]
    fn plan_new_zero_errors() {
        assert!(FftPlan::new(0).is_err());
    }

    #[test]
    fn plan_execute_roundtrip() {
        let plan = FftPlan::new(4).unwrap();
        let data = vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.0),
            Complex::new(4.0, 0.0),
        ];
        let signal = complex_array(data.clone());

        let spectrum = plan.execute(&signal).unwrap();
        let recovered = plan.execute_inverse(&spectrum).unwrap();

        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.re - rec.re).abs() < 1e-12);
            assert!((orig.im - rec.im).abs() < 1e-12);
        }
    }

    #[test]
    fn plan_forward_of_constant_is_dc_only() {
        // Unnormalized forward DFT of an all-ones signal of length n:
        // bin 0 holds n, every other bin is zero.
        let plan = FftPlan::new(8).unwrap();
        let signal = complex_array(vec![Complex::new(1.0, 0.0); 8]);
        let spectrum = plan.execute(&signal).unwrap();
        for (k, bin) in spectrum.iter().enumerate() {
            let expected = if k == 0 { 8.0 } else { 0.0 };
            assert!((bin.re - expected).abs() < 1e-12, "bin {k}: {bin:?}");
            assert!(bin.im.abs() < 1e-12, "bin {k}: {bin:?}");
        }
    }

    #[test]
    fn plan_size_mismatch() {
        let plan = FftPlan::new(8).unwrap();
        let short = complex_array(vec![Complex::new(0.0, 0.0); 4]);
        assert!(plan.execute(&short).is_err());
        assert!(plan.execute_inverse(&short).is_err());
    }

    #[test]
    fn cached_plan_reuse() {
        // Getting the same plan twice should return the same Arc
        let p1 = get_cached_plan(16, false);
        let p2 = get_cached_plan(16, false);
        assert!(Arc::ptr_eq(&p1, &p2));
    }

    #[test]
    fn cached_plan_keys_separate_directions() {
        // Forward and inverse of the same size must be distinct cache entries.
        let fwd = get_cached_plan(16, false);
        let inv = get_cached_plan(16, true);
        assert!(!Arc::ptr_eq(&fwd, &inv));
    }

    #[test]
    fn cached_plan_f32_reuse() {
        let p1 = get_cached_plan_f32(16, false);
        let p2 = get_cached_plan_f32(16, false);
        assert!(Arc::ptr_eq(&p1, &p2));
    }

    #[test]
    fn cached_real_plans_reuse() {
        let f64_fwd = get_cached_real_forward_f64(16);
        assert!(Arc::ptr_eq(&f64_fwd, &get_cached_real_forward_f64(16)));
        assert!(Arc::ptr_eq(&f64_fwd, &get_cached_real_forward(16)));

        let f64_inv = get_cached_real_inverse_f64(16);
        assert!(Arc::ptr_eq(&f64_inv, &get_cached_real_inverse_f64(16)));
        assert!(Arc::ptr_eq(&f64_inv, &get_cached_real_inverse(16)));

        let f32_fwd = get_cached_real_forward_f32(16);
        assert!(Arc::ptr_eq(&f32_fwd, &get_cached_real_forward_f32(16)));

        let f32_inv = get_cached_real_inverse_f32(16);
        assert!(Arc::ptr_eq(&f32_inv, &get_cached_real_inverse_f32(16)));
    }

    #[test]
    fn plan_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<FftPlan>();
    }
}