Skip to main content

ferray_fft/
plan.rs

1// ferray-fft: FftPlan type and global plan cache (REQ-12, REQ-13)
2
3use std::collections::HashMap;
4use std::sync::{Arc, LazyLock, Mutex};
5
6use num_complex::Complex;
7use rustfft::{Fft, FftPlanner};
8
9use ferray_core::error::{FerrayError, FerrayResult};
10use ferray_core::{Array, Ix1};
11
12use crate::norm::{FftDirection, FftNorm};
13
/// Lookup key for the global plan cache.
///
/// There is one entry per `(transform length, direction)` pair, where
/// `inverse == true` selects the inverse transform and `false` the forward one.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct CacheKey {
    size: usize,
    inverse: bool,
}
20
21/// Global plan cache: maps (size, direction) to a reusable FFT plan.
22///
23/// Thread-safe via `Mutex`. Plans are `Arc`-wrapped so they can be
24/// shared across threads without copying.
25static GLOBAL_CACHE: LazyLock<Mutex<PlanCache>> = LazyLock::new(|| Mutex::new(PlanCache::new()));
26
27struct PlanCache {
28    planner: FftPlanner<f64>,
29    plans: HashMap<CacheKey, Arc<dyn Fft<f64>>>,
30}
31
32impl PlanCache {
33    fn new() -> Self {
34        Self {
35            planner: FftPlanner::new(),
36            plans: HashMap::new(),
37        }
38    }
39
40    fn get_plan(&mut self, size: usize, inverse: bool) -> Arc<dyn Fft<f64>> {
41        let key = CacheKey { size, inverse };
42        self.plans
43            .entry(key)
44            .or_insert_with(|| {
45                if inverse {
46                    self.planner.plan_fft_inverse(size)
47                } else {
48                    self.planner.plan_fft_forward(size)
49                }
50            })
51            .clone()
52    }
53}
54
55/// Obtain a cached FFT plan for the given size and direction.
56///
57/// This is the primary internal entry point used by all FFT functions.
58/// Plans are cached globally so repeated transforms of the same size
59/// reuse the same plan.
60pub(crate) fn get_cached_plan(size: usize, inverse: bool) -> Arc<dyn Fft<f64>> {
61    let mut cache = GLOBAL_CACHE.lock().expect("FFT plan cache lock poisoned");
62    cache.get_plan(size, inverse)
63}
64
65/// A reusable FFT plan for a specific transform size.
66///
67/// `FftPlan` caches the internal FFT algorithm for a given size,
68/// enabling efficient repeated transforms. Plans are `Send + Sync`
69/// and can be shared across threads.
70///
71/// # Example
72/// ```
73/// use ferray_fft::FftPlan;
74/// use ferray_core::{Array, Ix1};
75/// use num_complex::Complex;
76///
77/// let plan = FftPlan::new(8).unwrap();
78/// let signal = Array::<Complex<f64>, Ix1>::from_vec(
79///     Ix1::new([8]),
80///     vec![Complex::new(1.0, 0.0); 8],
81/// ).unwrap();
82/// let result = plan.execute(&signal).unwrap();
83/// assert_eq!(result.shape(), &[8]);
84/// ```
85pub struct FftPlan {
86    forward: Arc<dyn Fft<f64>>,
87    inverse: Arc<dyn Fft<f64>>,
88    size: usize,
89}
90
91// FftPlan is Send + Sync because Arc<dyn Fft<f64>> is Send + Sync
92// (rustfft plans are thread-safe). No manual unsafe impl needed.
93
94impl FftPlan {
95    /// Create a new FFT plan for the given transform size.
96    ///
97    /// The plan pre-computes the internal FFT algorithm so that
98    /// subsequent calls to [`execute`](Self::execute) and
99    /// [`execute_inverse`](Self::execute_inverse) are fast.
100    ///
101    /// # Errors
102    /// Returns `FerrayError::InvalidValue` if `size` is 0.
103    pub fn new(size: usize) -> FerrayResult<Self> {
104        if size == 0 {
105            return Err(FerrayError::invalid_value("FFT plan size must be > 0"));
106        }
107        let forward = get_cached_plan(size, false);
108        let inverse = get_cached_plan(size, true);
109        Ok(Self {
110            forward,
111            inverse,
112            size,
113        })
114    }
115
116    /// Return the transform size this plan was created for.
117    pub fn size(&self) -> usize {
118        self.size
119    }
120
121    /// Execute a forward FFT on the given signal.
122    ///
123    /// The input array must have exactly `self.size()` elements.
124    /// Uses `FftNorm::Backward` (no scaling on forward).
125    ///
126    /// # Errors
127    /// Returns `FerrayError::ShapeMismatch` if the input length
128    /// does not match the plan size.
129    pub fn execute(
130        &self,
131        signal: &Array<Complex<f64>, Ix1>,
132    ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
133        self.execute_with_norm(signal, FftNorm::Backward)
134    }
135
136    /// Execute a forward FFT with the specified normalization.
137    ///
138    /// # Errors
139    /// Returns `FerrayError::ShapeMismatch` if the input length
140    /// does not match the plan size.
141    pub fn execute_with_norm(
142        &self,
143        signal: &Array<Complex<f64>, Ix1>,
144        norm: FftNorm,
145    ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
146        if signal.size() != self.size {
147            return Err(FerrayError::shape_mismatch(format!(
148                "signal length {} does not match plan size {}",
149                signal.size(),
150                self.size,
151            )));
152        }
153        let mut buffer: Vec<Complex<f64>> = signal.iter().copied().collect();
154        let mut scratch = vec![Complex::new(0.0, 0.0); self.forward.get_inplace_scratch_len()];
155        self.forward.process_with_scratch(&mut buffer, &mut scratch);
156
157        let scale = norm.scale_factor(self.size, FftDirection::Forward);
158        if (scale - 1.0).abs() > f64::EPSILON {
159            for c in &mut buffer {
160                *c *= scale;
161            }
162        }
163
164        Array::from_vec(Ix1::new([self.size]), buffer)
165    }
166
167    /// Execute an inverse FFT on the given spectrum.
168    ///
169    /// Uses `FftNorm::Backward` (divides by `n` on inverse).
170    ///
171    /// # Errors
172    /// Returns `FerrayError::ShapeMismatch` if the input length
173    /// does not match the plan size.
174    pub fn execute_inverse(
175        &self,
176        spectrum: &Array<Complex<f64>, Ix1>,
177    ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
178        self.execute_inverse_with_norm(spectrum, FftNorm::Backward)
179    }
180
181    /// Execute an inverse FFT with the specified normalization.
182    ///
183    /// # Errors
184    /// Returns `FerrayError::ShapeMismatch` if the input length
185    /// does not match the plan size.
186    pub fn execute_inverse_with_norm(
187        &self,
188        spectrum: &Array<Complex<f64>, Ix1>,
189        norm: FftNorm,
190    ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
191        if spectrum.size() != self.size {
192            return Err(FerrayError::shape_mismatch(format!(
193                "spectrum length {} does not match plan size {}",
194                spectrum.size(),
195                self.size,
196            )));
197        }
198        let mut buffer: Vec<Complex<f64>> = spectrum.iter().copied().collect();
199        let mut scratch = vec![Complex::new(0.0, 0.0); self.inverse.get_inplace_scratch_len()];
200        self.inverse.process_with_scratch(&mut buffer, &mut scratch);
201
202        let scale = norm.scale_factor(self.size, FftDirection::Inverse);
203        if (scale - 1.0).abs() > f64::EPSILON {
204            for c in &mut buffer {
205                *c *= scale;
206            }
207        }
208
209        Array::from_vec(Ix1::new([self.size]), buffer)
210    }
211}
212
213impl std::fmt::Debug for FftPlan {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        f.debug_struct("FftPlan").field("size", &self.size).finish()
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap `values` in a 1-D complex array of matching length.
    fn complex_array(values: Vec<Complex<f64>>) -> Array<Complex<f64>, Ix1> {
        let len = values.len();
        Array::<Complex<f64>, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    #[test]
    fn plan_new_valid() {
        let plan = FftPlan::new(8).unwrap();
        assert_eq!(plan.size(), 8);
    }

    #[test]
    fn plan_new_zero_errors() {
        let outcome = FftPlan::new(0);
        assert!(outcome.is_err());
    }

    #[test]
    fn plan_execute_roundtrip() {
        let plan = FftPlan::new(4).unwrap();
        let original: Vec<Complex<f64>> = [1.0, 2.0, 3.0, 4.0]
            .iter()
            .map(|&re| Complex::new(re, 0.0))
            .collect();

        let spectrum = plan.execute(&complex_array(original.clone())).unwrap();
        let recovered = plan.execute_inverse(&spectrum).unwrap();

        for (expected, actual) in original.iter().zip(recovered.iter()) {
            assert!((expected.re - actual.re).abs() < 1e-12);
            assert!((expected.im - actual.im).abs() < 1e-12);
        }
    }

    #[test]
    fn plan_size_mismatch() {
        let plan = FftPlan::new(8).unwrap();
        let too_short = complex_array(vec![Complex::new(0.0, 0.0); 4]);
        assert!(plan.execute(&too_short).is_err());
    }

    #[test]
    fn cached_plan_reuse() {
        // Two requests for the same key must hand back the very same Arc.
        let first = get_cached_plan(16, false);
        let second = get_cached_plan(16, false);
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn plan_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<FftPlan>();
    }
}