1use std::collections::HashMap;
4use std::sync::{Arc, LazyLock, Mutex};
5
6use num_complex::Complex;
7use rustfft::{Fft, FftPlanner};
8
9use ferray_core::error::{FerrayError, FerrayResult};
10use ferray_core::{Array, Ix1};
11
12use crate::norm::{FftDirection, FftNorm};
13
/// Lookup key for the FFT plan cache.
///
/// A plan is identified by its transform length and its direction: the forward
/// and inverse plans for the same length are stored as separate entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Transform length the plan was built for.
    size: usize,
    /// `true` for the inverse-direction plan, `false` for the forward one.
    inverse: bool,
}
20
21static GLOBAL_CACHE: LazyLock<Mutex<PlanCache>> = LazyLock::new(|| Mutex::new(PlanCache::new()));
26
/// Memoizing wrapper around a `rustfft` planner.
///
/// Plans are built on demand and stored behind `Arc`. Repeated requests for the
/// same length and direction therefore return the same shared plan object.
struct PlanCache {
    /// Planner used to build a plan the first time a key is requested.
    planner: FftPlanner<f64>,
    /// Plans built so far, keyed by (length, direction).
    plans: HashMap<CacheKey, Arc<dyn Fft<f64>>>,
}
31
32impl PlanCache {
33 fn new() -> Self {
34 Self {
35 planner: FftPlanner::new(),
36 plans: HashMap::new(),
37 }
38 }
39
40 fn get_plan(&mut self, size: usize, inverse: bool) -> Arc<dyn Fft<f64>> {
41 let key = CacheKey { size, inverse };
42 self.plans
43 .entry(key)
44 .or_insert_with(|| {
45 if inverse {
46 self.planner.plan_fft_inverse(size)
47 } else {
48 self.planner.plan_fft_forward(size)
49 }
50 })
51 .clone()
52 }
53}
54
55pub(crate) fn get_cached_plan(size: usize, inverse: bool) -> Arc<dyn Fft<f64>> {
61 let mut cache = GLOBAL_CACHE.lock().expect("FFT plan cache lock poisoned");
62 cache.get_plan(size, inverse)
63}
64
/// A reusable plan for complex-to-complex FFTs of one fixed length.
///
/// It holds both the forward and the inverse `rustfft` plans. These are obtained
/// from the process-wide plan cache, so plans of the same length share them.
pub struct FftPlan {
    /// Plan used by `execute` / `execute_with_norm`.
    forward: Arc<dyn Fft<f64>>,
    /// Plan used by `execute_inverse` / `execute_inverse_with_norm`.
    inverse: Arc<dyn Fft<f64>>,
    /// Transform length; `FftPlan::new` rejects zero.
    size: usize,
}
90
91impl FftPlan {
95 pub fn new(size: usize) -> FerrayResult<Self> {
104 if size == 0 {
105 return Err(FerrayError::invalid_value("FFT plan size must be > 0"));
106 }
107 let forward = get_cached_plan(size, false);
108 let inverse = get_cached_plan(size, true);
109 Ok(Self {
110 forward,
111 inverse,
112 size,
113 })
114 }
115
116 pub fn size(&self) -> usize {
118 self.size
119 }
120
121 pub fn execute(
130 &self,
131 signal: &Array<Complex<f64>, Ix1>,
132 ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
133 self.execute_with_norm(signal, FftNorm::Backward)
134 }
135
136 pub fn execute_with_norm(
142 &self,
143 signal: &Array<Complex<f64>, Ix1>,
144 norm: FftNorm,
145 ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
146 if signal.size() != self.size {
147 return Err(FerrayError::shape_mismatch(format!(
148 "signal length {} does not match plan size {}",
149 signal.size(),
150 self.size,
151 )));
152 }
153 let mut buffer: Vec<Complex<f64>> = signal.iter().copied().collect();
154 let mut scratch = vec![Complex::new(0.0, 0.0); self.forward.get_inplace_scratch_len()];
155 self.forward.process_with_scratch(&mut buffer, &mut scratch);
156
157 let scale = norm.scale_factor(self.size, FftDirection::Forward);
158 if (scale - 1.0).abs() > f64::EPSILON {
159 for c in &mut buffer {
160 *c *= scale;
161 }
162 }
163
164 Array::from_vec(Ix1::new([self.size]), buffer)
165 }
166
167 pub fn execute_inverse(
175 &self,
176 spectrum: &Array<Complex<f64>, Ix1>,
177 ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
178 self.execute_inverse_with_norm(spectrum, FftNorm::Backward)
179 }
180
181 pub fn execute_inverse_with_norm(
187 &self,
188 spectrum: &Array<Complex<f64>, Ix1>,
189 norm: FftNorm,
190 ) -> FerrayResult<Array<Complex<f64>, Ix1>> {
191 if spectrum.size() != self.size {
192 return Err(FerrayError::shape_mismatch(format!(
193 "spectrum length {} does not match plan size {}",
194 spectrum.size(),
195 self.size,
196 )));
197 }
198 let mut buffer: Vec<Complex<f64>> = spectrum.iter().copied().collect();
199 let mut scratch = vec![Complex::new(0.0, 0.0); self.inverse.get_inplace_scratch_len()];
200 self.inverse.process_with_scratch(&mut buffer, &mut scratch);
201
202 let scale = norm.scale_factor(self.size, FftDirection::Inverse);
203 if (scale - 1.0).abs() > f64::EPSILON {
204 for c in &mut buffer {
205 *c *= scale;
206 }
207 }
208
209 Array::from_vec(Ix1::new([self.size]), buffer)
210 }
211}
212
213impl std::fmt::Debug for FftPlan {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 f.debug_struct("FftPlan").field("size", &self.size).finish()
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing transform outputs of small, exactly representable inputs.
    const TOL: f64 = 1e-12;

    /// Builds a 1-D complex array from `data`.
    fn array_of(data: Vec<Complex<f64>>) -> Array<Complex<f64>, Ix1> {
        let len = data.len();
        Array::<Complex<f64>, Ix1>::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Asserts that two complex slices agree element-wise within `TOL`.
    ///
    /// Lengths are checked first, because a bare `zip` would silently ignore any
    /// missing or extra elements.
    fn assert_all_close(actual: &[Complex<f64>], expected: &[Complex<f64>]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a.re - e.re).abs() < TOL && (a.im - e.im).abs() < TOL,
                "element {i}: got {a:?}, expected {e:?}"
            );
        }
    }

    #[test]
    fn plan_new_valid() {
        let plan = FftPlan::new(8).unwrap();
        assert_eq!(plan.size(), 8);
    }

    #[test]
    fn plan_new_zero_errors() {
        assert!(FftPlan::new(0).is_err());
    }

    #[test]
    fn plan_execute_roundtrip() {
        let plan = FftPlan::new(4).unwrap();
        let data = vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.0),
            Complex::new(4.0, 0.0),
        ];
        let signal = array_of(data.clone());

        let spectrum = plan.execute(&signal).unwrap();
        let recovered: Vec<Complex<f64>> = plan
            .execute_inverse(&spectrum)
            .unwrap()
            .iter()
            .copied()
            .collect();

        assert_all_close(&recovered, &data);
    }

    #[test]
    fn plan_execute_matches_direct_dft() {
        // X_k = sum_n x_n * exp(-2*pi*i*k*n/4) for x = [1, 2, 3, 4].
        // This assumes the default `Backward` norm leaves the forward transform
        // unscaled; the roundtrip alone cannot pin that down.
        let plan = FftPlan::new(4).unwrap();
        let signal = array_of(vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.0),
            Complex::new(4.0, 0.0),
        ]);

        let spectrum: Vec<Complex<f64>> =
            plan.execute(&signal).unwrap().iter().copied().collect();

        let expected = [
            Complex::new(10.0, 0.0),
            Complex::new(-2.0, 2.0),
            Complex::new(-2.0, 0.0),
            Complex::new(-2.0, -2.0),
        ];
        assert_all_close(&spectrum, &expected);
    }

    #[test]
    fn plan_size_mismatch() {
        let plan = FftPlan::new(8).unwrap();
        let signal = array_of(vec![Complex::new(0.0, 0.0); 4]);
        assert!(plan.execute(&signal).is_err());
    }

    #[test]
    fn plan_inverse_size_mismatch() {
        let plan = FftPlan::new(8).unwrap();
        let spectrum = array_of(vec![Complex::new(0.0, 0.0); 4]);
        assert!(plan.execute_inverse(&spectrum).is_err());
    }

    #[test]
    fn cached_plan_reuse() {
        let p1 = get_cached_plan(16, false);
        let p2 = get_cached_plan(16, false);
        assert!(Arc::ptr_eq(&p1, &p2));
    }

    #[test]
    fn cached_plan_distinguishes_direction() {
        let forward = get_cached_plan(16, false);
        let inverse = get_cached_plan(16, true);
        assert!(!Arc::ptr_eq(&forward, &inverse));
    }

    #[test]
    fn plans_of_same_size_share_cached_fft() {
        let a = FftPlan::new(32).unwrap();
        let b = FftPlan::new(32).unwrap();
        assert!(Arc::ptr_eq(&a.forward, &b.forward));
        assert!(Arc::ptr_eq(&a.inverse, &b.inverse));
    }

    #[test]
    fn plan_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<FftPlan>();
    }
}