1use std::collections::HashMap;
4use std::sync::{Arc, LazyLock, Mutex};
5
6use num_complex::Complex;
7use rustfft::{Fft, FftPlanner};
8
9use ferray_core::error::{FerrayError, FerrayResult};
10use ferray_core::{Array, Ix1};
11
12use crate::norm::{FftDirection, FftNorm};
13
/// Key identifying one cached rustfft plan: the transform length plus its direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Transform length the plan was built for.
    size: usize,
    /// `true` for an inverse-direction plan, `false` for a forward one.
    inverse: bool,
}
20
/// Process-wide plan cache, created lazily on first access and guarded by a mutex so
/// that all threads share the same memoized plans.
static GLOBAL_CACHE: LazyLock<Mutex<PlanCache>> = LazyLock::new(|| Mutex::new(PlanCache::new()));
26
/// Memoizing wrapper around a rustfft planner.
struct PlanCache {
    /// Planner used to build a plan the first time a `(size, direction)` pair is requested.
    planner: FftPlanner<f64>,
    /// Plans built so far, shared out to callers as `Arc`s.
    plans: HashMap<CacheKey, Arc<dyn Fft<f64>>>,
}
31
32impl PlanCache {
33 fn new() -> Self {
34 Self {
35 planner: FftPlanner::new(),
36 plans: HashMap::new(),
37 }
38 }
39
40 fn get_plan(&mut self, size: usize, inverse: bool) -> Arc<dyn Fft<f64>> {
41 let key = CacheKey { size, inverse };
42 self.plans
43 .entry(key)
44 .or_insert_with(|| {
45 if inverse {
46 self.planner.plan_fft_inverse(size)
47 } else {
48 self.planner.plan_fft_forward(size)
49 }
50 })
51 .clone()
52 }
53}
54
55pub(crate) fn get_cached_plan(size: usize, inverse: bool) -> Arc<dyn Fft<f64>> {
61 let mut cache = GLOBAL_CACHE.lock().expect("FFT plan cache lock poisoned");
62 cache.get_plan(size, inverse)
63}
64
/// A reusable complex-to-complex FFT plan of a fixed length.
///
/// Holds both the forward and inverse rustfft plans for `size`, so one `FftPlan`
/// can transform signals and invert spectra of that length. Construct it with
/// [`FftPlan::new`].
pub struct FftPlan {
    /// Forward-direction plan of length `size`.
    forward: Arc<dyn Fft<f64>>,
    /// Inverse-direction plan of length `size`.
    inverse: Arc<dyn Fft<f64>>,
    /// Transform length; inputs must contain exactly this many elements.
    size: usize,
}
90
91unsafe impl Send for FftPlan {}
93unsafe impl Sync for FftPlan {}
94
95impl FftPlan {
96 pub fn new(size: usize) -> FerrayResult<Self> {
105 if size == 0 {
106 return Err(FerrayError::invalid_value("FFT plan size must be > 0"));
107 }
108 let forward = get_cached_plan(size, false);
109 let inverse = get_cached_plan(size, true);
110 Ok(Self {
111 forward,
112 inverse,
113 size,
114 })
115 }
116
117 pub fn size(&self) -> usize {
119 self.size
120 }
121
122 pub fn execute(
131 &self,
132 signal: &Array<Complex<f64>, Ix1>,
133 ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
134 self.execute_with_norm(signal, FftNorm::Backward)
135 }
136
137 pub fn execute_with_norm(
143 &self,
144 signal: &Array<Complex<f64>, Ix1>,
145 norm: FftNorm,
146 ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
147 if signal.size() != self.size {
148 return Err(FerrayError::shape_mismatch(format!(
149 "signal length {} does not match plan size {}",
150 signal.size(),
151 self.size,
152 )));
153 }
154 let mut buffer: Vec<Complex<f64>> = signal.iter().copied().collect();
155 let mut scratch = vec![Complex::new(0.0, 0.0); self.forward.get_inplace_scratch_len()];
156 self.forward.process_with_scratch(&mut buffer, &mut scratch);
157
158 let scale = norm.scale_factor(self.size, FftDirection::Forward);
159 if (scale - 1.0).abs() > f64::EPSILON {
160 for c in &mut buffer {
161 *c *= scale;
162 }
163 }
164
165 Array::from_vec(Ix1::new([self.size]), buffer)
166 }
167
168 pub fn execute_inverse(
176 &self,
177 spectrum: &Array<Complex<f64>, Ix1>,
178 ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
179 self.execute_inverse_with_norm(spectrum, FftNorm::Backward)
180 }
181
182 pub fn execute_inverse_with_norm(
188 &self,
189 spectrum: &Array<Complex<f64>, Ix1>,
190 norm: FftNorm,
191 ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
192 if spectrum.size() != self.size {
193 return Err(FerrayError::shape_mismatch(format!(
194 "spectrum length {} does not match plan size {}",
195 spectrum.size(),
196 self.size,
197 )));
198 }
199 let mut buffer: Vec<Complex<f64>> = spectrum.iter().copied().collect();
200 let mut scratch = vec![Complex::new(0.0, 0.0); self.inverse.get_inplace_scratch_len()];
201 self.inverse.process_with_scratch(&mut buffer, &mut scratch);
202
203 let scale = norm.scale_factor(self.size, FftDirection::Inverse);
204 if (scale - 1.0).abs() > f64::EPSILON {
205 for c in &mut buffer {
206 *c *= scale;
207 }
208 }
209
210 Array::from_vec(Ix1::new([self.size]), buffer)
211 }
212}
213
214impl std::fmt::Debug for FftPlan {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 f.debug_struct("FftPlan").field("size", &self.size).finish()
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `data` in a 1-D array whose length matches the vector.
    fn array1(data: Vec<Complex<f64>>) -> Array<Complex<f64>, Ix1> {
        let len = data.len();
        Array::<Complex<f64>, Ix1>::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Asserts that `actual` has the same length as `expected` and matches it
    /// element-wise to within 1e-12. A bare `zip` would silently ignore extra or
    /// missing elements.
    fn assert_close<'a>(
        actual: impl IntoIterator<Item = &'a Complex<f64>>,
        expected: &[Complex<f64>],
    ) {
        let actual: Vec<Complex<f64>> = actual.into_iter().copied().collect();
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a.re - e.re).abs() < 1e-12 && (a.im - e.im).abs() < 1e-12,
                "element {i}: got {a:?}, expected {e:?}"
            );
        }
    }

    #[test]
    fn plan_new_valid() {
        let plan = FftPlan::new(8).unwrap();
        assert_eq!(plan.size(), 8);
    }

    #[test]
    fn plan_new_zero_errors() {
        assert!(FftPlan::new(0).is_err());
    }

    #[test]
    fn plan_forward_known_spectrum() {
        // The unscaled DFT of [1, 2, 3, 4] is [10, -2 + 2i, -2, -2 - 2i].
        // This catches swapped forward/inverse plans, which a roundtrip alone cannot.
        let plan = FftPlan::new(4).unwrap();
        let signal = array1(vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.0),
            Complex::new(4.0, 0.0),
        ]);

        let spectrum = plan.execute(&signal).unwrap();

        assert_close(
            spectrum.iter(),
            &[
                Complex::new(10.0, 0.0),
                Complex::new(-2.0, 2.0),
                Complex::new(-2.0, 0.0),
                Complex::new(-2.0, -2.0),
            ],
        );
    }

    #[test]
    fn plan_execute_roundtrip() {
        let plan = FftPlan::new(4).unwrap();
        let data = vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.0),
            Complex::new(4.0, 0.0),
        ];
        let signal = array1(data.clone());

        let spectrum = plan.execute(&signal).unwrap();
        let recovered = plan.execute_inverse(&spectrum).unwrap();

        assert_close(recovered.iter(), &data);
    }

    #[test]
    fn plan_size_mismatch() {
        let plan = FftPlan::new(8).unwrap();
        let signal = array1(vec![Complex::new(0.0, 0.0); 4]);
        assert!(plan.execute(&signal).is_err());
    }

    #[test]
    fn plan_inverse_size_mismatch() {
        let plan = FftPlan::new(8).unwrap();
        let spectrum = array1(vec![Complex::new(0.0, 0.0); 4]);
        assert!(plan.execute_inverse(&spectrum).is_err());
    }

    #[test]
    fn cached_plan_reuse() {
        let p1 = get_cached_plan(16, false);
        let p2 = get_cached_plan(16, false);
        assert!(Arc::ptr_eq(&p1, &p2));
    }

    #[test]
    fn cached_plan_keys_on_direction() {
        let forward = get_cached_plan(16, false);
        let inverse = get_cached_plan(16, true);
        assert!(!Arc::ptr_eq(&forward, &inverse));
    }

    #[test]
    fn plan_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<FftPlan>();
    }
}