Skip to main content

ferray_fft/
plan.rs

1// ferray-fft: FftPlan type and global plan cache (REQ-12, REQ-13)
2
3use std::collections::HashMap;
4use std::sync::{Arc, LazyLock, Mutex};
5
6use num_complex::Complex;
7use rustfft::{Fft, FftPlanner};
8
9use ferray_core::error::{FerrayError, FerrayResult};
10use ferray_core::{Array, Ix1};
11
12use crate::norm::{FftDirection, FftNorm};
13
/// Lookup key of the global FFT plan cache.
///
/// Two keys are equal (and hash identically) exactly when both the
/// transform length and the direction flag match.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct CacheKey {
    /// Number of points in the transform.
    size: usize,
    /// `true` for an inverse transform, `false` for a forward one.
    inverse: bool,
}
20
21/// Global plan cache: maps (size, direction) to a reusable FFT plan.
22///
23/// Thread-safe via `Mutex`. Plans are `Arc`-wrapped so they can be
24/// shared across threads without copying.
25static GLOBAL_CACHE: LazyLock<Mutex<PlanCache>> = LazyLock::new(|| Mutex::new(PlanCache::new()));
26
27struct PlanCache {
28    planner: FftPlanner<f64>,
29    plans: HashMap<CacheKey, Arc<dyn Fft<f64>>>,
30}
31
32impl PlanCache {
33    fn new() -> Self {
34        Self {
35            planner: FftPlanner::new(),
36            plans: HashMap::new(),
37        }
38    }
39
40    fn get_plan(&mut self, size: usize, inverse: bool) -> Arc<dyn Fft<f64>> {
41        let key = CacheKey { size, inverse };
42        self.plans
43            .entry(key)
44            .or_insert_with(|| {
45                if inverse {
46                    self.planner.plan_fft_inverse(size)
47                } else {
48                    self.planner.plan_fft_forward(size)
49                }
50            })
51            .clone()
52    }
53}
54
55/// Obtain a cached FFT plan for the given size and direction.
56///
57/// This is the primary internal entry point used by all FFT functions.
58/// Plans are cached globally so repeated transforms of the same size
59/// reuse the same plan.
60pub(crate) fn get_cached_plan(size: usize, inverse: bool) -> Arc<dyn Fft<f64>> {
61    let mut cache = GLOBAL_CACHE.lock().expect("FFT plan cache lock poisoned");
62    cache.get_plan(size, inverse)
63}
64
65/// A reusable FFT plan for a specific transform size.
66///
67/// `FftPlan` caches the internal FFT algorithm for a given size,
68/// enabling efficient repeated transforms. Plans are `Send + Sync`
69/// and can be shared across threads.
70///
71/// # Example
72/// ```
73/// use ferray_fft::FftPlan;
74/// use ferray_core::{Array, Ix1};
75/// use num_complex::Complex;
76///
77/// let plan = FftPlan::new(8).unwrap();
78/// let signal = Array::<Complex<f64>, Ix1>::from_vec(
79///     Ix1::new([8]),
80///     vec![Complex::new(1.0, 0.0); 8],
81/// ).unwrap();
82/// let result = plan.execute(&signal).unwrap();
83/// assert_eq!(result.shape(), &[8]);
84/// ```
85pub struct FftPlan {
86    forward: Arc<dyn Fft<f64>>,
87    inverse: Arc<dyn Fft<f64>>,
88    size: usize,
89}
90
91// Arc<dyn Fft<f64>> is Send + Sync because rustfft plans are thread-safe
92unsafe impl Send for FftPlan {}
93unsafe impl Sync for FftPlan {}
94
95impl FftPlan {
96    /// Create a new FFT plan for the given transform size.
97    ///
98    /// The plan pre-computes the internal FFT algorithm so that
99    /// subsequent calls to [`execute`](Self::execute) and
100    /// [`execute_inverse`](Self::execute_inverse) are fast.
101    ///
102    /// # Errors
103    /// Returns `FerrayError::InvalidValue` if `size` is 0.
104    pub fn new(size: usize) -> FerrayResult<Self> {
105        if size == 0 {
106            return Err(FerrayError::invalid_value("FFT plan size must be > 0"));
107        }
108        let forward = get_cached_plan(size, false);
109        let inverse = get_cached_plan(size, true);
110        Ok(Self {
111            forward,
112            inverse,
113            size,
114        })
115    }
116
117    /// Return the transform size this plan was created for.
118    pub fn size(&self) -> usize {
119        self.size
120    }
121
122    /// Execute a forward FFT on the given signal.
123    ///
124    /// The input array must have exactly `self.size()` elements.
125    /// Uses `FftNorm::Backward` (no scaling on forward).
126    ///
127    /// # Errors
128    /// Returns `FerrayError::ShapeMismatch` if the input length
129    /// does not match the plan size.
130    pub fn execute(
131        &self,
132        signal: &Array<Complex<f64>, Ix1>,
133    ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
134        self.execute_with_norm(signal, FftNorm::Backward)
135    }
136
137    /// Execute a forward FFT with the specified normalization.
138    ///
139    /// # Errors
140    /// Returns `FerrayError::ShapeMismatch` if the input length
141    /// does not match the plan size.
142    pub fn execute_with_norm(
143        &self,
144        signal: &Array<Complex<f64>, Ix1>,
145        norm: FftNorm,
146    ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
147        if signal.size() != self.size {
148            return Err(FerrayError::shape_mismatch(format!(
149                "signal length {} does not match plan size {}",
150                signal.size(),
151                self.size,
152            )));
153        }
154        let mut buffer: Vec<Complex<f64>> = signal.iter().copied().collect();
155        let mut scratch = vec![Complex::new(0.0, 0.0); self.forward.get_inplace_scratch_len()];
156        self.forward.process_with_scratch(&mut buffer, &mut scratch);
157
158        let scale = norm.scale_factor(self.size, FftDirection::Forward);
159        if (scale - 1.0).abs() > f64::EPSILON {
160            for c in &mut buffer {
161                *c *= scale;
162            }
163        }
164
165        Array::from_vec(Ix1::new([self.size]), buffer)
166    }
167
168    /// Execute an inverse FFT on the given spectrum.
169    ///
170    /// Uses `FftNorm::Backward` (divides by `n` on inverse).
171    ///
172    /// # Errors
173    /// Returns `FerrayError::ShapeMismatch` if the input length
174    /// does not match the plan size.
175    pub fn execute_inverse(
176        &self,
177        spectrum: &Array<Complex<f64>, Ix1>,
178    ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
179        self.execute_inverse_with_norm(spectrum, FftNorm::Backward)
180    }
181
182    /// Execute an inverse FFT with the specified normalization.
183    ///
184    /// # Errors
185    /// Returns `FerrayError::ShapeMismatch` if the input length
186    /// does not match the plan size.
187    pub fn execute_inverse_with_norm(
188        &self,
189        spectrum: &Array<Complex<f64>, Ix1>,
190        norm: FftNorm,
191    ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
192        if spectrum.size() != self.size {
193            return Err(FerrayError::shape_mismatch(format!(
194                "spectrum length {} does not match plan size {}",
195                spectrum.size(),
196                self.size,
197            )));
198        }
199        let mut buffer: Vec<Complex<f64>> = spectrum.iter().copied().collect();
200        let mut scratch = vec![Complex::new(0.0, 0.0); self.inverse.get_inplace_scratch_len()];
201        self.inverse.process_with_scratch(&mut buffer, &mut scratch);
202
203        let scale = norm.scale_factor(self.size, FftDirection::Inverse);
204        if (scale - 1.0).abs() > f64::EPSILON {
205            for c in &mut buffer {
206                *c *= scale;
207            }
208        }
209
210        Array::from_vec(Ix1::new([self.size]), buffer)
211    }
212}
213
214impl std::fmt::Debug for FftPlan {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        f.debug_struct("FftPlan").field("size", &self.size).finish()
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap `values` in a 1-D complex array of matching length.
    fn complex_array(values: Vec<Complex<f64>>) -> Array<Complex<f64>, Ix1> {
        let len = values.len();
        Array::<Complex<f64>, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    #[test]
    fn plan_new_valid() {
        let plan = FftPlan::new(8).unwrap();
        assert_eq!(plan.size(), 8);
    }

    #[test]
    fn plan_new_zero_errors() {
        let outcome = FftPlan::new(0);
        assert!(outcome.is_err());
    }

    #[test]
    fn plan_execute_roundtrip() {
        let plan = FftPlan::new(4).unwrap();
        let data: Vec<Complex<f64>> = [1.0, 2.0, 3.0, 4.0]
            .iter()
            .map(|&re| Complex::new(re, 0.0))
            .collect();
        let signal = complex_array(data.clone());

        // Forward followed by inverse must reproduce the input.
        let spectrum = plan.execute(&signal).unwrap();
        let recovered = plan.execute_inverse(&spectrum).unwrap();

        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!((orig.re - rec.re).abs() < 1e-12);
            assert!((orig.im - rec.im).abs() < 1e-12);
        }
    }

    #[test]
    fn plan_size_mismatch() {
        let plan = FftPlan::new(8).unwrap();
        // Four samples fed to an eight-point plan must be rejected.
        let signal = complex_array(vec![Complex::new(0.0, 0.0); 4]);
        assert!(plan.execute(&signal).is_err());
    }

    #[test]
    fn cached_plan_reuse() {
        // Two lookups with the same key must hand back the same allocation.
        let first = get_cached_plan(16, false);
        let second = get_cached_plan(16, false);
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn plan_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<FftPlan>();
    }
}