Skip to main content

aocl_fft/
lib.rs

1//! Safe wrappers for AOCL-FFTW.
2//!
3//! Provides:
4//! - One-shot 1-D / 2-D / 3-D complex DFTs (in-place) in both `f64`
5//!   (default names) and `f32` (suffixed `_f32`).
6//! - One-shot 1-D real-to-complex (`r2c`) and complex-to-real (`c2r`)
7//!   transforms in both precisions.
8//! - Reusable plan types: [`Plan`] (`f64`) and [`PlanF32`] (`f32`) that
9//!   cache plan creation and support FFTW's "new-array execute"
10//!   routines for repeated transforms over the same shape.
11//!
12//! Plan creation/destruction is serialized through a process-wide
13//! mutex because FFTW's planner is not internally thread-safe in the
14//! single-threaded build we link against. Plan **execution** is
15//! thread-safe and not held under that lock. The same lock guards
16//! both precisions since they share the planner state.
17
18#![warn(missing_debug_implementations)]
19#![cfg_attr(docsrs, feature(doc_cfg))]
20
21pub use aocl_error::{Error, Result};
22use aocl_fft_sys as sys;
23use std::sync::Mutex;
24
/// FFTW's plan-creation and plan-destruction routines mutate global state.
///
/// Every call into a `*_plan_*` or `*_destroy_plan` routine in this file is
/// made while holding this lock; plan execution is performed without it.
/// The mutex carries no data (`()`), so a poisoned lock is recovered with
/// `into_inner()` at each call site rather than treated as an error.
static PLANNER_LOCK: Mutex<()> = Mutex::new(());
27
/// Direction of a complex DFT.
///
/// Passed to the complex one-shot functions and to the complex plan
/// constructors; it is translated to the FFTW sign constant internally.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Forward transform (FFTW_FORWARD = -1).
    Forward,
    /// Backward / inverse transform (FFTW_BACKWARD = +1). FFTW does
    /// **not** divide by `n` — call sites that want a strict inverse
    /// must scale the output themselves.
    Backward,
}
38
impl Direction {
    /// Map this direction onto the FFTW sign constant expected by the
    /// `*_plan_dft_*` planner routines.
    fn raw(self) -> std::os::raw::c_int {
        match self {
            Direction::Forward => sys::FFTW_FORWARD,
            // NOTE(review): only the backward constant is cast; presumably the
            // generated binding gives it a different integer type than
            // `FFTW_FORWARD` — confirm against `aocl_fft_sys`.
            Direction::Backward => sys::FFTW_BACKWARD as i32,
        }
    }
}
47
48fn check_n(name: &str, n: usize) -> Result<i32> {
49    if n > i32::MAX as usize {
50        return Err(Error::InvalidArgument(format!(
51            "{name}: dimension {n} exceeds i32::MAX"
52        )));
53    }
54    Ok(n as i32)
55}
56
57// =========================================================================
58//   One-shot complex DFTs (1-D / 2-D / 3-D)
59// =========================================================================
60
61/// Compute a 1-D complex DFT in place. `data` is treated as `n = data.len()`
62/// complex samples in `[real, imag]` order.
63pub fn dft_1d_inplace(direction: Direction, data: &mut [[f64; 2]]) -> Result<()> {
64    if data.is_empty() {
65        return Ok(());
66    }
67    let n = check_n("dft_1d_inplace", data.len())?;
68    let ptr = data.as_mut_ptr() as *mut sys::fftw_complex;
69    let plan = {
70        let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
71        unsafe { sys::fftw_plan_dft_1d(n, ptr, ptr, direction.raw(), sys::FFTW_ESTIMATE) }
72    };
73    if plan.is_null() {
74        return Err(Error::AllocationFailed("fft"));
75    }
76    unsafe { sys::fftw_execute(plan) };
77    let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
78    unsafe { sys::fftw_destroy_plan(plan) };
79    Ok(())
80}
81
82/// Compute a 2-D complex DFT in place. `data` is `n0 × n1` complex samples
83/// in row-major order (length `n0 · n1`).
84pub fn dft_2d_inplace(
85    direction: Direction,
86    n0: usize,
87    n1: usize,
88    data: &mut [[f64; 2]],
89) -> Result<()> {
90    if n0 == 0 || n1 == 0 {
91        return Ok(());
92    }
93    let need = n0 * n1;
94    if data.len() < need {
95        return Err(Error::InvalidArgument(format!(
96            "dft_2d_inplace: data length {} < n0·n1 = {need}",
97            data.len()
98        )));
99    }
100    let n0 = check_n("dft_2d_inplace: n0", n0)?;
101    let n1 = check_n("dft_2d_inplace: n1", n1)?;
102    let ptr = data.as_mut_ptr() as *mut sys::fftw_complex;
103    let plan = {
104        let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
105        unsafe { sys::fftw_plan_dft_2d(n0, n1, ptr, ptr, direction.raw(), sys::FFTW_ESTIMATE) }
106    };
107    if plan.is_null() {
108        return Err(Error::AllocationFailed("fft"));
109    }
110    unsafe { sys::fftw_execute(plan) };
111    let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
112    unsafe { sys::fftw_destroy_plan(plan) };
113    Ok(())
114}
115
116/// Compute a 3-D complex DFT in place.
117pub fn dft_3d_inplace(
118    direction: Direction,
119    n0: usize,
120    n1: usize,
121    n2: usize,
122    data: &mut [[f64; 2]],
123) -> Result<()> {
124    if n0 == 0 || n1 == 0 || n2 == 0 {
125        return Ok(());
126    }
127    let need = n0 * n1 * n2;
128    if data.len() < need {
129        return Err(Error::InvalidArgument(format!(
130            "dft_3d_inplace: data length {} < n0·n1·n2 = {need}",
131            data.len()
132        )));
133    }
134    let n0 = check_n("dft_3d_inplace: n0", n0)?;
135    let n1 = check_n("dft_3d_inplace: n1", n1)?;
136    let n2 = check_n("dft_3d_inplace: n2", n2)?;
137    let ptr = data.as_mut_ptr() as *mut sys::fftw_complex;
138    let plan = {
139        let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
140        unsafe { sys::fftw_plan_dft_3d(n0, n1, n2, ptr, ptr, direction.raw(), sys::FFTW_ESTIMATE) }
141    };
142    if plan.is_null() {
143        return Err(Error::AllocationFailed("fft"));
144    }
145    unsafe { sys::fftw_execute(plan) };
146    let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
147    unsafe { sys::fftw_destroy_plan(plan) };
148    Ok(())
149}
150
151/// Convenience: forward 1-D DFT in place.
152pub fn forward_inplace(data: &mut [[f64; 2]]) -> Result<()> {
153    dft_1d_inplace(Direction::Forward, data)
154}
155
156/// Convenience: backward (unscaled) 1-D DFT in place.
157pub fn backward_inplace(data: &mut [[f64; 2]]) -> Result<()> {
158    dft_1d_inplace(Direction::Backward, data)
159}
160
161// =========================================================================
162//   Real ↔ complex one-shot (1-D)
163// =========================================================================
164
165/// Compute a 1-D real-to-complex (forward) DFT.
166///
167/// `out` must hold at least `n/2 + 1` complex samples (the second half of
168/// the spectrum is implied by Hermitian symmetry).
169pub fn r2c_1d(input: &mut [f64], output: &mut [[f64; 2]]) -> Result<()> {
170    let n = input.len();
171    if n == 0 {
172        return Ok(());
173    }
174    let need_out = n / 2 + 1;
175    if output.len() < need_out {
176        return Err(Error::InvalidArgument(format!(
177            "r2c_1d: output length {} < n/2+1 = {need_out}",
178            output.len()
179        )));
180    }
181    let n_i = check_n("r2c_1d", n)?;
182    let in_p = input.as_mut_ptr();
183    let out_p = output.as_mut_ptr() as *mut sys::fftw_complex;
184    let plan = {
185        let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
186        unsafe { sys::fftw_plan_dft_r2c_1d(n_i, in_p, out_p, sys::FFTW_ESTIMATE) }
187    };
188    if plan.is_null() {
189        return Err(Error::AllocationFailed("fft"));
190    }
191    unsafe { sys::fftw_execute(plan) };
192    let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
193    unsafe { sys::fftw_destroy_plan(plan) };
194    Ok(())
195}
196
197/// Compute a 1-D complex-to-real (backward) DFT.
198///
199/// `n` is the size of the *output* (the real time-domain signal). `input`
200/// must hold at least `n/2 + 1` complex samples. Note FFTW does **not**
201/// normalize — divide by `n` to recover the original time-domain signal
202/// after a forward+backward round-trip.
203pub fn c2r_1d(n: usize, input: &mut [[f64; 2]], output: &mut [f64]) -> Result<()> {
204    if n == 0 {
205        return Ok(());
206    }
207    let need_in = n / 2 + 1;
208    if input.len() < need_in {
209        return Err(Error::InvalidArgument(format!(
210            "c2r_1d: input length {} < n/2+1 = {need_in}",
211            input.len()
212        )));
213    }
214    if output.len() < n {
215        return Err(Error::InvalidArgument(format!(
216            "c2r_1d: output length {} < n = {n}",
217            output.len()
218        )));
219    }
220    let n_i = check_n("c2r_1d", n)?;
221    let in_p = input.as_mut_ptr() as *mut sys::fftw_complex;
222    let out_p = output.as_mut_ptr();
223    let plan = {
224        let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
225        unsafe { sys::fftw_plan_dft_c2r_1d(n_i, in_p, out_p, sys::FFTW_ESTIMATE) }
226    };
227    if plan.is_null() {
228        return Err(Error::AllocationFailed("fft"));
229    }
230    unsafe { sys::fftw_execute(plan) };
231    let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
232    unsafe { sys::fftw_destroy_plan(plan) };
233    Ok(())
234}
235
236// =========================================================================
237//   Reusable Plan
238// =========================================================================
239
/// Kind of transform a [`Plan`] performs.
///
/// Recorded at construction and checked by each `execute_*` method so a
/// plan cannot be run through the wrong FFTW execute routine.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum PlanKind {
    /// Complex-to-complex DFT (set by the `dft_1d` / `dft_2d` / `dft_3d` constructors).
    Complex,
    /// 1-D real-to-complex transform (set by the `r2c_1d` constructors).
    R2C,
    /// 1-D complex-to-real transform (set by the `c2r_1d` constructors).
    C2R,
}
247
/// A reusable FFTW plan over a fixed shape and transform kind.
///
/// Construct once with [`Plan::dft_1d`] / [`Plan::dft_2d`] / [`Plan::dft_3d`]
/// / [`Plan::r2c_1d`] / [`Plan::c2r_1d`], then call
/// [`Plan::execute_dft`] / [`Plan::execute_r2c`] / [`Plan::execute_c2r`]
/// repeatedly with new buffers of the matching shape. The buffers used at
/// execute time must have the same alignment as the buffers used at plan
/// creation; plain `Vec<…>` and `[…; N]` storage on x86_64 satisfy this.
///
/// NOTE(review): the alignment requirement is not checked by the execute
/// methods; whether a sub-slice of a larger buffer also satisfies it should
/// be confirmed against the FFTW new-array execute documentation.
pub struct Plan {
    /// Raw FFTW plan handle; released by the `Drop` impl under `PLANNER_LOCK`.
    plan: sys::fftw_plan,
    /// Which `execute_*` method this plan may be used with.
    kind: PlanKind,
    /// Total number of complex elements (for dft) or real elements (for r2c/c2r).
    n_total: usize,
}
262
// SAFETY (rationale as stated in this crate's module docs): plan execution is
// thread-safe, and the only other use of the handle — destruction in `Drop` —
// is serialized through `PLANNER_LOCK`, so moving a `Plan` to another thread
// is treated as sound.
unsafe impl Send for Plan {}
264
265impl Plan {
266    /// Build a 1-D complex DFT plan over `n` samples in the given direction.
267    pub fn dft_1d(n: usize, direction: Direction) -> Result<Self> {
268        if n == 0 {
269            return Err(Error::InvalidArgument("dft_1d: n must be positive".into()));
270        }
271        let n_i = check_n("dft_1d", n)?;
272        // Create a scratch buffer for plan-time pointer requirements;
273        // FFTW_ESTIMATE does not actually run the transform, so the
274        // contents are irrelevant.
275        let mut scratch = vec![[0.0_f64, 0.0_f64]; n];
276        let p = scratch.as_mut_ptr() as *mut sys::fftw_complex;
277        let plan = {
278            let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
279            unsafe { sys::fftw_plan_dft_1d(n_i, p, p, direction.raw(), sys::FFTW_ESTIMATE) }
280        };
281        if plan.is_null() {
282            return Err(Error::AllocationFailed("fft"));
283        }
284        Ok(Self {
285            plan,
286            kind: PlanKind::Complex,
287            n_total: n,
288        })
289    }
290
291    /// Build a 2-D complex DFT plan over `n0 × n1` samples.
292    pub fn dft_2d(n0: usize, n1: usize, direction: Direction) -> Result<Self> {
293        if n0 == 0 || n1 == 0 {
294            return Err(Error::InvalidArgument(
295                "dft_2d: dimensions must be positive".into(),
296            ));
297        }
298        let n0_i = check_n("dft_2d: n0", n0)?;
299        let n1_i = check_n("dft_2d: n1", n1)?;
300        let mut scratch = vec![[0.0_f64, 0.0_f64]; n0 * n1];
301        let p = scratch.as_mut_ptr() as *mut sys::fftw_complex;
302        let plan = {
303            let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
304            unsafe { sys::fftw_plan_dft_2d(n0_i, n1_i, p, p, direction.raw(), sys::FFTW_ESTIMATE) }
305        };
306        if plan.is_null() {
307            return Err(Error::AllocationFailed("fft"));
308        }
309        Ok(Self {
310            plan,
311            kind: PlanKind::Complex,
312            n_total: n0 * n1,
313        })
314    }
315
316    /// Build a 3-D complex DFT plan.
317    pub fn dft_3d(n0: usize, n1: usize, n2: usize, direction: Direction) -> Result<Self> {
318        if n0 == 0 || n1 == 0 || n2 == 0 {
319            return Err(Error::InvalidArgument(
320                "dft_3d: dimensions must be positive".into(),
321            ));
322        }
323        let n0_i = check_n("dft_3d: n0", n0)?;
324        let n1_i = check_n("dft_3d: n1", n1)?;
325        let n2_i = check_n("dft_3d: n2", n2)?;
326        let mut scratch = vec![[0.0_f64, 0.0_f64]; n0 * n1 * n2];
327        let p = scratch.as_mut_ptr() as *mut sys::fftw_complex;
328        let plan = {
329            let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
330            unsafe {
331                sys::fftw_plan_dft_3d(n0_i, n1_i, n2_i, p, p, direction.raw(), sys::FFTW_ESTIMATE)
332            }
333        };
334        if plan.is_null() {
335            return Err(Error::AllocationFailed("fft"));
336        }
337        Ok(Self {
338            plan,
339            kind: PlanKind::Complex,
340            n_total: n0 * n1 * n2,
341        })
342    }
343
344    /// Build a 1-D real-to-complex forward plan.
345    pub fn r2c_1d(n: usize) -> Result<Self> {
346        if n == 0 {
347            return Err(Error::InvalidArgument("r2c_1d: n must be positive".into()));
348        }
349        let n_i = check_n("r2c_1d", n)?;
350        let mut in_buf = vec![0.0_f64; n];
351        let mut out_buf = vec![[0.0_f64, 0.0_f64]; n / 2 + 1];
352        let plan = {
353            let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
354            unsafe {
355                sys::fftw_plan_dft_r2c_1d(
356                    n_i,
357                    in_buf.as_mut_ptr(),
358                    out_buf.as_mut_ptr() as *mut sys::fftw_complex,
359                    sys::FFTW_ESTIMATE,
360                )
361            }
362        };
363        if plan.is_null() {
364            return Err(Error::AllocationFailed("fft"));
365        }
366        Ok(Self {
367            plan,
368            kind: PlanKind::R2C,
369            n_total: n,
370        })
371    }
372
373    /// Build a 1-D complex-to-real backward plan.
374    pub fn c2r_1d(n: usize) -> Result<Self> {
375        if n == 0 {
376            return Err(Error::InvalidArgument("c2r_1d: n must be positive".into()));
377        }
378        let n_i = check_n("c2r_1d", n)?;
379        let mut in_buf = vec![[0.0_f64, 0.0_f64]; n / 2 + 1];
380        let mut out_buf = vec![0.0_f64; n];
381        let plan = {
382            let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
383            unsafe {
384                sys::fftw_plan_dft_c2r_1d(
385                    n_i,
386                    in_buf.as_mut_ptr() as *mut sys::fftw_complex,
387                    out_buf.as_mut_ptr(),
388                    sys::FFTW_ESTIMATE,
389                )
390            }
391        };
392        if plan.is_null() {
393            return Err(Error::AllocationFailed("fft"));
394        }
395        Ok(Self {
396            plan,
397            kind: PlanKind::C2R,
398            n_total: n,
399        })
400    }
401
402    /// Execute a complex DFT plan against new buffers of the matching shape.
403    pub fn execute_dft(&self, input: &mut [[f64; 2]], output: &mut [[f64; 2]]) -> Result<()> {
404        if self.kind != PlanKind::Complex {
405            return Err(Error::InvalidArgument(
406                "execute_dft: plan is not a complex DFT".into(),
407            ));
408        }
409        if input.len() < self.n_total || output.len() < self.n_total {
410            return Err(Error::InvalidArgument(format!(
411                "execute_dft: buffers ({}, {}) smaller than plan size {}",
412                input.len(),
413                output.len(),
414                self.n_total
415            )));
416        }
417        unsafe {
418            sys::fftw_execute_dft(
419                self.plan,
420                input.as_mut_ptr() as *mut sys::fftw_complex,
421                output.as_mut_ptr() as *mut sys::fftw_complex,
422            );
423        }
424        Ok(())
425    }
426
427    /// Execute a real-to-complex plan against new buffers.
428    pub fn execute_r2c(&self, input: &mut [f64], output: &mut [[f64; 2]]) -> Result<()> {
429        if self.kind != PlanKind::R2C {
430            return Err(Error::InvalidArgument(
431                "execute_r2c: plan is not an r2c plan".into(),
432            ));
433        }
434        let need_out = self.n_total / 2 + 1;
435        if input.len() < self.n_total || output.len() < need_out {
436            return Err(Error::InvalidArgument(format!(
437                "execute_r2c: input {} < n={}; output {} < n/2+1={}",
438                input.len(),
439                self.n_total,
440                output.len(),
441                need_out
442            )));
443        }
444        unsafe {
445            sys::fftw_execute_dft_r2c(
446                self.plan,
447                input.as_mut_ptr(),
448                output.as_mut_ptr() as *mut sys::fftw_complex,
449            );
450        }
451        Ok(())
452    }
453
454    /// Execute a complex-to-real plan against new buffers.
455    pub fn execute_c2r(&self, input: &mut [[f64; 2]], output: &mut [f64]) -> Result<()> {
456        if self.kind != PlanKind::C2R {
457            return Err(Error::InvalidArgument(
458                "execute_c2r: plan is not a c2r plan".into(),
459            ));
460        }
461        let need_in = self.n_total / 2 + 1;
462        if input.len() < need_in || output.len() < self.n_total {
463            return Err(Error::InvalidArgument(format!(
464                "execute_c2r: input {} < n/2+1={}; output {} < n={}",
465                input.len(),
466                need_in,
467                output.len(),
468                self.n_total
469            )));
470        }
471        unsafe {
472            sys::fftw_execute_dft_c2r(
473                self.plan,
474                input.as_mut_ptr() as *mut sys::fftw_complex,
475                output.as_mut_ptr(),
476            );
477        }
478        Ok(())
479    }
480}
481
482impl Drop for Plan {
483    fn drop(&mut self) {
484        if !self.plan.is_null() {
485            let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
486            unsafe { sys::fftw_destroy_plan(self.plan) };
487            self.plan = std::ptr::null_mut();
488        }
489    }
490}
491
492impl std::fmt::Debug for Plan {
493    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494        f.debug_struct("Plan")
495            .field("kind", &self.kind)
496            .field("n_total", &self.n_total)
497            .finish_non_exhaustive()
498    }
499}
500
501// =========================================================================
502//   Single-precision (f32) one-shot complex DFTs
503// =========================================================================
504
505/// Compute a 1-D complex DFT in place over `f32` samples.
506pub fn dft_1d_inplace_f32(direction: Direction, data: &mut [[f32; 2]]) -> Result<()> {
507    if data.is_empty() {
508        return Ok(());
509    }
510    let n = check_n("dft_1d_inplace_f32", data.len())?;
511    let ptr = data.as_mut_ptr() as *mut sys::fftwf_complex;
512    let plan = {
513        let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
514        unsafe { sys::fftwf_plan_dft_1d(n, ptr, ptr, direction.raw(), sys::FFTW_ESTIMATE) }
515    };
516    if plan.is_null() {
517        return Err(Error::AllocationFailed("fft"));
518    }
519    unsafe { sys::fftwf_execute(plan) };
520    let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
521    unsafe { sys::fftwf_destroy_plan(plan) };
522    Ok(())
523}
524
525/// Compute a 2-D complex DFT in place over `f32` samples (row-major).
526pub fn dft_2d_inplace_f32(
527    direction: Direction,
528    n0: usize,
529    n1: usize,
530    data: &mut [[f32; 2]],
531) -> Result<()> {
532    if n0 == 0 || n1 == 0 {
533        return Ok(());
534    }
535    let need = n0 * n1;
536    if data.len() < need {
537        return Err(Error::InvalidArgument(format!(
538            "dft_2d_inplace_f32: data length {} < n0·n1 = {need}",
539            data.len()
540        )));
541    }
542    let n0 = check_n("dft_2d_inplace_f32: n0", n0)?;
543    let n1 = check_n("dft_2d_inplace_f32: n1", n1)?;
544    let ptr = data.as_mut_ptr() as *mut sys::fftwf_complex;
545    let plan = {
546        let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
547        unsafe { sys::fftwf_plan_dft_2d(n0, n1, ptr, ptr, direction.raw(), sys::FFTW_ESTIMATE) }
548    };
549    if plan.is_null() {
550        return Err(Error::AllocationFailed("fft"));
551    }
552    unsafe { sys::fftwf_execute(plan) };
553    let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
554    unsafe { sys::fftwf_destroy_plan(plan) };
555    Ok(())
556}
557
558/// Compute a 3-D complex DFT in place over `f32` samples.
559pub fn dft_3d_inplace_f32(
560    direction: Direction,
561    n0: usize,
562    n1: usize,
563    n2: usize,
564    data: &mut [[f32; 2]],
565) -> Result<()> {
566    if n0 == 0 || n1 == 0 || n2 == 0 {
567        return Ok(());
568    }
569    let need = n0 * n1 * n2;
570    if data.len() < need {
571        return Err(Error::InvalidArgument(format!(
572            "dft_3d_inplace_f32: data length {} < n0·n1·n2 = {need}",
573            data.len()
574        )));
575    }
576    let n0 = check_n("dft_3d_inplace_f32: n0", n0)?;
577    let n1 = check_n("dft_3d_inplace_f32: n1", n1)?;
578    let n2 = check_n("dft_3d_inplace_f32: n2", n2)?;
579    let ptr = data.as_mut_ptr() as *mut sys::fftwf_complex;
580    let plan = {
581        let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
582        unsafe { sys::fftwf_plan_dft_3d(n0, n1, n2, ptr, ptr, direction.raw(), sys::FFTW_ESTIMATE) }
583    };
584    if plan.is_null() {
585        return Err(Error::AllocationFailed("fft"));
586    }
587    unsafe { sys::fftwf_execute(plan) };
588    let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
589    unsafe { sys::fftwf_destroy_plan(plan) };
590    Ok(())
591}
592
593/// Convenience: forward 1-D `f32` DFT in place.
594pub fn forward_inplace_f32(data: &mut [[f32; 2]]) -> Result<()> {
595    dft_1d_inplace_f32(Direction::Forward, data)
596}
597
598/// Convenience: backward (unscaled) 1-D `f32` DFT in place.
599pub fn backward_inplace_f32(data: &mut [[f32; 2]]) -> Result<()> {
600    dft_1d_inplace_f32(Direction::Backward, data)
601}
602
603// =========================================================================
604//   Single-precision (f32) real ↔ complex one-shot (1-D)
605// =========================================================================
606
607/// Compute a 1-D real-to-complex (forward) `f32` DFT.
608///
609/// `output` must hold at least `n/2 + 1` complex samples.
610pub fn r2c_1d_f32(input: &mut [f32], output: &mut [[f32; 2]]) -> Result<()> {
611    let n = input.len();
612    if n == 0 {
613        return Ok(());
614    }
615    let need_out = n / 2 + 1;
616    if output.len() < need_out {
617        return Err(Error::InvalidArgument(format!(
618            "r2c_1d_f32: output length {} < n/2+1 = {need_out}",
619            output.len()
620        )));
621    }
622    let n_i = check_n("r2c_1d_f32", n)?;
623    let in_p = input.as_mut_ptr();
624    let out_p = output.as_mut_ptr() as *mut sys::fftwf_complex;
625    let plan = {
626        let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
627        unsafe { sys::fftwf_plan_dft_r2c_1d(n_i, in_p, out_p, sys::FFTW_ESTIMATE) }
628    };
629    if plan.is_null() {
630        return Err(Error::AllocationFailed("fft"));
631    }
632    unsafe { sys::fftwf_execute(plan) };
633    let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
634    unsafe { sys::fftwf_destroy_plan(plan) };
635    Ok(())
636}
637
638/// Compute a 1-D complex-to-real (backward) `f32` DFT.
639///
640/// `n` is the size of the *output*. As with the f64 version, FFTW does
641/// not normalise — divide by `n` to recover the original after a
642/// forward+backward round-trip.
643pub fn c2r_1d_f32(n: usize, input: &mut [[f32; 2]], output: &mut [f32]) -> Result<()> {
644    if n == 0 {
645        return Ok(());
646    }
647    let need_in = n / 2 + 1;
648    if input.len() < need_in {
649        return Err(Error::InvalidArgument(format!(
650            "c2r_1d_f32: input length {} < n/2+1 = {need_in}",
651            input.len()
652        )));
653    }
654    if output.len() < n {
655        return Err(Error::InvalidArgument(format!(
656            "c2r_1d_f32: output length {} < n = {n}",
657            output.len()
658        )));
659    }
660    let n_i = check_n("c2r_1d_f32", n)?;
661    let in_p = input.as_mut_ptr() as *mut sys::fftwf_complex;
662    let out_p = output.as_mut_ptr();
663    let plan = {
664        let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
665        unsafe { sys::fftwf_plan_dft_c2r_1d(n_i, in_p, out_p, sys::FFTW_ESTIMATE) }
666    };
667    if plan.is_null() {
668        return Err(Error::AllocationFailed("fft"));
669    }
670    unsafe { sys::fftwf_execute(plan) };
671    let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
672    unsafe { sys::fftwf_destroy_plan(plan) };
673    Ok(())
674}
675
676// =========================================================================
677//   Reusable single-precision plan
678// =========================================================================
679
/// A reusable single-precision FFTW plan over a fixed shape.
///
/// Mirrors [`Plan`] but for `f32` / `fftwf_complex` buffers.
pub struct PlanF32 {
    /// Raw single-precision FFTW plan handle.
    plan: sys::fftwf_plan,
    /// Which `execute_*` method this plan may be used with.
    kind: PlanKind,
    /// Total number of complex elements (complex plans) or real elements
    /// (r2c/c2r plans) the plan was built for.
    n_total: usize,
}
688
// SAFETY (rationale as stated in this crate's module docs): plan execution is
// thread-safe, so moving a `PlanF32` to another thread is treated as sound.
// NOTE(review): this also relies on destruction being serialized through
// `PLANNER_LOCK` as it is for `Plan` — confirm against the `Drop` impl.
unsafe impl Send for PlanF32 {}
690
691impl PlanF32 {
692    /// Build a 1-D complex DFT plan over `n` samples.
693    pub fn dft_1d(n: usize, direction: Direction) -> Result<Self> {
694        if n == 0 {
695            return Err(Error::InvalidArgument("dft_1d: n must be positive".into()));
696        }
697        let n_i = check_n("dft_1d", n)?;
698        let mut scratch = vec![[0.0_f32, 0.0_f32]; n];
699        let p = scratch.as_mut_ptr() as *mut sys::fftwf_complex;
700        let plan = {
701            let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
702            unsafe { sys::fftwf_plan_dft_1d(n_i, p, p, direction.raw(), sys::FFTW_ESTIMATE) }
703        };
704        if plan.is_null() {
705            return Err(Error::AllocationFailed("fft"));
706        }
707        Ok(Self {
708            plan,
709            kind: PlanKind::Complex,
710            n_total: n,
711        })
712    }
713
714    /// Build a 2-D complex DFT plan over `n0 × n1` samples.
715    pub fn dft_2d(n0: usize, n1: usize, direction: Direction) -> Result<Self> {
716        if n0 == 0 || n1 == 0 {
717            return Err(Error::InvalidArgument(
718                "dft_2d: dimensions must be positive".into(),
719            ));
720        }
721        let n0_i = check_n("dft_2d: n0", n0)?;
722        let n1_i = check_n("dft_2d: n1", n1)?;
723        let mut scratch = vec![[0.0_f32, 0.0_f32]; n0 * n1];
724        let p = scratch.as_mut_ptr() as *mut sys::fftwf_complex;
725        let plan = {
726            let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
727            unsafe { sys::fftwf_plan_dft_2d(n0_i, n1_i, p, p, direction.raw(), sys::FFTW_ESTIMATE) }
728        };
729        if plan.is_null() {
730            return Err(Error::AllocationFailed("fft"));
731        }
732        Ok(Self {
733            plan,
734            kind: PlanKind::Complex,
735            n_total: n0 * n1,
736        })
737    }
738
739    /// Build a 3-D complex DFT plan.
740    pub fn dft_3d(n0: usize, n1: usize, n2: usize, direction: Direction) -> Result<Self> {
741        if n0 == 0 || n1 == 0 || n2 == 0 {
742            return Err(Error::InvalidArgument(
743                "dft_3d: dimensions must be positive".into(),
744            ));
745        }
746        let n0_i = check_n("dft_3d: n0", n0)?;
747        let n1_i = check_n("dft_3d: n1", n1)?;
748        let n2_i = check_n("dft_3d: n2", n2)?;
749        let mut scratch = vec![[0.0_f32, 0.0_f32]; n0 * n1 * n2];
750        let p = scratch.as_mut_ptr() as *mut sys::fftwf_complex;
751        let plan = {
752            let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
753            unsafe {
754                sys::fftwf_plan_dft_3d(n0_i, n1_i, n2_i, p, p, direction.raw(), sys::FFTW_ESTIMATE)
755            }
756        };
757        if plan.is_null() {
758            return Err(Error::AllocationFailed("fft"));
759        }
760        Ok(Self {
761            plan,
762            kind: PlanKind::Complex,
763            n_total: n0 * n1 * n2,
764        })
765    }
766
767    /// Build a 1-D real-to-complex forward plan.
768    pub fn r2c_1d(n: usize) -> Result<Self> {
769        if n == 0 {
770            return Err(Error::InvalidArgument("r2c_1d: n must be positive".into()));
771        }
772        let n_i = check_n("r2c_1d", n)?;
773        let mut in_buf = vec![0.0_f32; n];
774        let mut out_buf = vec![[0.0_f32, 0.0_f32]; n / 2 + 1];
775        let plan = {
776            let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
777            unsafe {
778                sys::fftwf_plan_dft_r2c_1d(
779                    n_i,
780                    in_buf.as_mut_ptr(),
781                    out_buf.as_mut_ptr() as *mut sys::fftwf_complex,
782                    sys::FFTW_ESTIMATE,
783                )
784            }
785        };
786        if plan.is_null() {
787            return Err(Error::AllocationFailed("fft"));
788        }
789        Ok(Self {
790            plan,
791            kind: PlanKind::R2C,
792            n_total: n,
793        })
794    }
795
796    /// Build a 1-D complex-to-real backward plan.
797    pub fn c2r_1d(n: usize) -> Result<Self> {
798        if n == 0 {
799            return Err(Error::InvalidArgument("c2r_1d: n must be positive".into()));
800        }
801        let n_i = check_n("c2r_1d", n)?;
802        let mut in_buf = vec![[0.0_f32, 0.0_f32]; n / 2 + 1];
803        let mut out_buf = vec![0.0_f32; n];
804        let plan = {
805            let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
806            unsafe {
807                sys::fftwf_plan_dft_c2r_1d(
808                    n_i,
809                    in_buf.as_mut_ptr() as *mut sys::fftwf_complex,
810                    out_buf.as_mut_ptr(),
811                    sys::FFTW_ESTIMATE,
812                )
813            }
814        };
815        if plan.is_null() {
816            return Err(Error::AllocationFailed("fft"));
817        }
818        Ok(Self {
819            plan,
820            kind: PlanKind::C2R,
821            n_total: n,
822        })
823    }
824
825    /// Execute a complex DFT plan against new buffers.
826    pub fn execute_dft(&self, input: &mut [[f32; 2]], output: &mut [[f32; 2]]) -> Result<()> {
827        if self.kind != PlanKind::Complex {
828            return Err(Error::InvalidArgument(
829                "execute_dft: plan is not a complex DFT".into(),
830            ));
831        }
832        if input.len() < self.n_total || output.len() < self.n_total {
833            return Err(Error::InvalidArgument(format!(
834                "execute_dft: buffers ({}, {}) smaller than plan size {}",
835                input.len(),
836                output.len(),
837                self.n_total
838            )));
839        }
840        unsafe {
841            sys::fftwf_execute_dft(
842                self.plan,
843                input.as_mut_ptr() as *mut sys::fftwf_complex,
844                output.as_mut_ptr() as *mut sys::fftwf_complex,
845            );
846        }
847        Ok(())
848    }
849
850    /// Execute a real-to-complex plan against new buffers.
851    pub fn execute_r2c(&self, input: &mut [f32], output: &mut [[f32; 2]]) -> Result<()> {
852        if self.kind != PlanKind::R2C {
853            return Err(Error::InvalidArgument(
854                "execute_r2c: plan is not an r2c plan".into(),
855            ));
856        }
857        let need_out = self.n_total / 2 + 1;
858        if input.len() < self.n_total || output.len() < need_out {
859            return Err(Error::InvalidArgument(format!(
860                "execute_r2c: input {} < n={}; output {} < n/2+1={}",
861                input.len(),
862                self.n_total,
863                output.len(),
864                need_out
865            )));
866        }
867        unsafe {
868            sys::fftwf_execute_dft_r2c(
869                self.plan,
870                input.as_mut_ptr(),
871                output.as_mut_ptr() as *mut sys::fftwf_complex,
872            );
873        }
874        Ok(())
875    }
876
877    /// Execute a complex-to-real plan against new buffers.
878    pub fn execute_c2r(&self, input: &mut [[f32; 2]], output: &mut [f32]) -> Result<()> {
879        if self.kind != PlanKind::C2R {
880            return Err(Error::InvalidArgument(
881                "execute_c2r: plan is not a c2r plan".into(),
882            ));
883        }
884        let need_in = self.n_total / 2 + 1;
885        if input.len() < need_in || output.len() < self.n_total {
886            return Err(Error::InvalidArgument(format!(
887                "execute_c2r: input {} < n/2+1={}; output {} < n={}",
888                input.len(),
889                need_in,
890                output.len(),
891                self.n_total
892            )));
893        }
894        unsafe {
895            sys::fftwf_execute_dft_c2r(
896                self.plan,
897                input.as_mut_ptr() as *mut sys::fftwf_complex,
898                output.as_mut_ptr(),
899            );
900        }
901        Ok(())
902    }
903}
904
905impl Drop for PlanF32 {
906    fn drop(&mut self) {
907        if !self.plan.is_null() {
908            let _g = PLANNER_LOCK.lock().unwrap_or_else(|e| e.into_inner());
909            unsafe { sys::fftwf_destroy_plan(self.plan) };
910            self.plan = std::ptr::null_mut();
911        }
912    }
913}
914
915impl std::fmt::Debug for PlanF32 {
916    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
917        f.debug_struct("PlanF32")
918            .field("kind", &self.kind)
919            .field("n_total", &self.n_total)
920            .finish_non_exhaustive()
921    }
922}
923
/// Unit tests for the one-shot transforms and the reusable plan types.
///
/// FFTW transforms are unnormalized, so every round-trip test divides the
/// recovered signal by `n` before comparing against the original.
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward followed by backward returns `n` times the original signal.
    #[test]
    fn dft_inverse_recovers_input() {
        let original: [[f64; 2]; 4] = [[1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0]];
        let mut buf = original;
        forward_inplace(&mut buf).unwrap();
        backward_inplace(&mut buf).unwrap();
        let n = original.len() as f64;
        for (got, orig) in buf.iter().zip(original.iter()) {
            assert!((got[0] / n - orig[0]).abs() < 1e-12);
            assert!((got[1] / n - orig[1]).abs() < 1e-12);
        }
    }

    /// A constant signal transforms to a single non-zero bin at index 0.
    #[test]
    fn dc_signal_concentrates_at_dc() {
        let n = 8;
        let c = 1.5_f64;
        let mut buf: Vec<[f64; 2]> = (0..n).map(|_| [c, 0.0]).collect();
        forward_inplace(&mut buf).unwrap();
        let n_f = n as f64;
        assert!((buf[0][0] - c * n_f).abs() < 1e-12);
        assert!(buf[0][1].abs() < 1e-12);
        for v in &buf[1..] {
            assert!(v[0].abs() < 1e-12);
            assert!(v[1].abs() < 1e-12);
        }
    }

    /// Transforming an empty buffer succeeds.
    #[test]
    fn empty_input_is_ok() {
        let mut empty: [[f64; 2]; 0] = [];
        forward_inplace(&mut empty).unwrap();
    }

    #[test]
    fn dft_2d_dc() {
        // 4×4 DC signal → spectrum has all energy at (0,0).
        let n0 = 4;
        let n1 = 4;
        let mut buf: Vec<[f64; 2]> = vec![[1.0, 0.0]; n0 * n1];
        dft_2d_inplace(Direction::Forward, n0, n1, &mut buf).unwrap();
        let total = (n0 * n1) as f64;
        assert!((buf[0][0] - total).abs() < 1e-9);
        assert!(buf[0][1].abs() < 1e-9);
        for v in &buf[1..] {
            assert!(v[0].abs() < 1e-9);
            assert!(v[1].abs() < 1e-9);
        }
    }

    /// r2c followed by c2r recovers the real signal up to the factor `n`.
    #[test]
    fn r2c_then_c2r_round_trip() {
        let original = [1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let n = original.len();
        let mut input = original;
        let mut spectrum = vec![[0.0_f64, 0.0]; n / 2 + 1];
        r2c_1d(&mut input, &mut spectrum).unwrap();
        let mut recovered = vec![0.0_f64; n];
        c2r_1d(n, &mut spectrum, &mut recovered).unwrap();
        // FFTW does not normalize; round-trip multiplies by n.
        for (got, orig) in recovered.iter().zip(original.iter()) {
            assert!((got / n as f64 - orig).abs() < 1e-10);
        }
    }

    /// One plan can be executed on several independent buffer pairs.
    #[test]
    fn plan_executes_repeatedly() {
        let plan = Plan::dft_1d(4, Direction::Forward).unwrap();
        let mut a: Vec<[f64; 2]> = vec![[1.0, 0.0]; 4];
        let mut out_a = vec![[0.0, 0.0]; 4];
        plan.execute_dft(&mut a, &mut out_a).unwrap();
        assert!((out_a[0][0] - 4.0).abs() < 1e-9);

        let mut b: Vec<[f64; 2]> = vec![[2.0, 0.0]; 4];
        let mut out_b = vec![[0.0, 0.0]; 4];
        plan.execute_dft(&mut b, &mut out_b).unwrap();
        assert!((out_b[0][0] - 8.0).abs() < 1e-9);
    }

    /// Executing an r2c plan through the complex-DFT entry point is rejected.
    #[test]
    fn plan_kind_mismatch_is_error() {
        let plan = Plan::r2c_1d(8).unwrap();
        let mut input: Vec<[f64; 2]> = vec![[0.0, 0.0]; 8];
        let mut output: Vec<[f64; 2]> = vec![[0.0, 0.0]; 8];
        let err = plan.execute_dft(&mut input, &mut output).unwrap_err();
        assert!(matches!(err, Error::InvalidArgument(_)));
    }

    // -------------------------------------------------------------------
    //   f32 tests
    // -------------------------------------------------------------------

    /// Single-precision counterpart of `dft_inverse_recovers_input`.
    #[test]
    fn dft_inverse_recovers_input_f32() {
        let original: [[f32; 2]; 4] = [[1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0]];
        let mut buf = original;
        forward_inplace_f32(&mut buf).unwrap();
        backward_inplace_f32(&mut buf).unwrap();
        let n = original.len() as f32;
        for (got, orig) in buf.iter().zip(original.iter()) {
            assert!((got[0] / n - orig[0]).abs() < 1e-5);
            assert!((got[1] / n - orig[1]).abs() < 1e-5);
        }
    }

    /// Single-precision counterpart of `dc_signal_concentrates_at_dc`.
    #[test]
    fn dc_signal_concentrates_at_dc_f32() {
        let n = 8;
        let c = 1.5_f32;
        let mut buf: Vec<[f32; 2]> = (0..n).map(|_| [c, 0.0]).collect();
        forward_inplace_f32(&mut buf).unwrap();
        let n_f = n as f32;
        assert!((buf[0][0] - c * n_f).abs() < 1e-4);
        assert!(buf[0][1].abs() < 1e-4);
        for v in &buf[1..] {
            assert!(v[0].abs() < 1e-4);
            assert!(v[1].abs() < 1e-4);
        }
    }

    /// Single-precision counterpart of `r2c_then_c2r_round_trip`.
    #[test]
    fn r2c_then_c2r_round_trip_f32() {
        let original = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let n = original.len();
        let mut input = original;
        let mut spectrum = vec![[0.0_f32, 0.0]; n / 2 + 1];
        r2c_1d_f32(&mut input, &mut spectrum).unwrap();
        let mut recovered = vec![0.0_f32; n];
        c2r_1d_f32(n, &mut spectrum, &mut recovered).unwrap();
        for (got, orig) in recovered.iter().zip(original.iter()) {
            assert!((got / n as f32 - orig).abs() < 1e-4);
        }
    }

    /// Single-precision counterpart of `plan_executes_repeatedly`.
    #[test]
    fn plan_f32_executes_repeatedly() {
        let plan = PlanF32::dft_1d(4, Direction::Forward).unwrap();
        let mut a: Vec<[f32; 2]> = vec![[1.0, 0.0]; 4];
        let mut out_a = vec![[0.0_f32, 0.0]; 4];
        plan.execute_dft(&mut a, &mut out_a).unwrap();
        assert!((out_a[0][0] - 4.0).abs() < 1e-4);

        let mut b: Vec<[f32; 2]> = vec![[2.0, 0.0]; 4];
        let mut out_b = vec![[0.0_f32, 0.0]; 4];
        plan.execute_dft(&mut b, &mut out_b).unwrap();
        assert!((out_b[0][0] - 8.0).abs() < 1e-4);
    }

    /// Executing a c2r plan through the complex-DFT entry point is rejected
    /// before any FFTW call is made.
    #[test]
    fn plan_f32_kind_mismatch_is_error() {
        let plan = PlanF32::c2r_1d(8).unwrap();
        let mut input: Vec<[f32; 2]> = vec![[0.0, 0.0]; 8];
        let mut output: Vec<[f32; 2]> = vec![[0.0, 0.0]; 8];
        let err = plan.execute_dft(&mut input, &mut output).unwrap_err();
        assert!(matches!(err, Error::InvalidArgument(_)));
    }

    /// A zero-length c2r plan is refused at construction time.
    #[test]
    fn plan_f32_c2r_rejects_zero_length() {
        let err = PlanF32::c2r_1d(0).unwrap_err();
        assert!(matches!(err, Error::InvalidArgument(_)));
    }

    /// `execute_c2r` validates both buffer lengths against the plan size.
    #[test]
    fn plan_f32_c2r_rejects_short_buffers() {
        let n = 8;
        let plan = PlanF32::c2r_1d(n).unwrap();

        // Input one bin short of the required n / 2 + 1.
        let mut short_in = vec![[0.0_f32, 0.0]; n / 2];
        let mut out = vec![0.0_f32; n];
        let err = plan.execute_c2r(&mut short_in, &mut out).unwrap_err();
        assert!(matches!(err, Error::InvalidArgument(_)));

        // Output one sample short of the required n.
        let mut input = vec![[0.0_f32, 0.0]; n / 2 + 1];
        let mut short_out = vec![0.0_f32; n - 1];
        let err = plan.execute_c2r(&mut input, &mut short_out).unwrap_err();
        assert!(matches!(err, Error::InvalidArgument(_)));
    }
}