use crate::api::{export_to_string, import_from_string};
use crate::kernel::{Complex, Float, IoDim, Tensor};
use crate::{Direction, Flags, GuruPlan, Plan, Plan2D, Plan3D, RealPlan};
/// FFTW-style constructor for a one-dimensional complex DFT plan in double precision.
///
/// Thin forwarder to [`Plan::dft_1d`] for `f64`; whatever that planner yields
/// (including `None` when it declines the request) is returned unchanged.
#[must_use]
pub fn fftw_plan_dft_1d(n: usize, direction: Direction, flags: Flags) -> Option<Plan<f64>> {
    Plan::<f64>::dft_1d(n, direction, flags)
}
/// FFTW-style (`fftwf_` prefix) constructor for a one-dimensional complex DFT
/// plan in single precision.
///
/// Forwards to [`Plan::dft_1d`] instantiated at `f32` and returns its result as is.
#[must_use]
pub fn fftwf_plan_dft_1d(n: usize, direction: Direction, flags: Flags) -> Option<Plan<f32>> {
    Plan::<f32>::dft_1d(n, direction, flags)
}
/// FFTW-style constructor for a two-dimensional `n0 x n1` complex DFT plan (`f64`).
///
/// Delegates to [`Plan::dft_2d`]; a `None` from the planner is passed through.
#[must_use]
pub fn fftw_plan_dft_2d(
    n0: usize,
    n1: usize,
    direction: Direction,
    flags: Flags,
) -> Option<Plan2D<f64>> {
    Plan::<f64>::dft_2d(n0, n1, direction, flags)
}
/// FFTW-style constructor for a three-dimensional `n0 x n1 x n2` complex DFT
/// plan (`f64`).
///
/// Delegates to [`Plan::dft_3d`]; a `None` from the planner is passed through.
#[must_use]
pub fn fftw_plan_dft_3d(
    n0: usize,
    n1: usize,
    n2: usize,
    direction: Direction,
    flags: Flags,
) -> Option<Plan3D<f64>> {
    Plan::<f64>::dft_3d(n0, n1, n2, direction, flags)
}
/// FFTW-style constructor for a one-dimensional real-to-complex plan (`f64`).
///
/// Forwards to [`Plan::r2c_1d`]; the result, including `None`, is returned unchanged.
#[must_use]
pub fn fftw_plan_dft_r2c_1d(n: usize, flags: Flags) -> Option<RealPlan<f64>> {
    Plan::<f64>::r2c_1d(n, flags)
}
/// FFTW-style constructor for a one-dimensional complex-to-real plan (`f64`).
///
/// Forwards to [`Plan::c2r_1d`]; the result, including `None`, is returned unchanged.
#[must_use]
pub fn fftw_plan_dft_c2r_1d(n: usize, flags: Flags) -> Option<RealPlan<f64>> {
    Plan::<f64>::c2r_1d(n, flags)
}
/// FFTW-style `plan_many_dft`: plan `howmany` complex DFTs, each over a densely
/// packed, row-major array of shape `ns`.
///
/// `ns[0]` is the slowest-varying dimension and `ns[rank - 1]` the fastest
/// (unit stride). Consecutive transforms in the batch are laid out back to
/// back: transform `k` starts at element `k * ns.iter().product()` in both the
/// input and the output buffer.
///
/// Returns `None` when any of the following holds:
/// - `rank != ns.len()`, `rank == 0`, or `howmany == 0`;
/// - any extent in `ns` is zero;
/// - a stride, or the total extent of the batch, does not fit in `isize`;
/// - the underlying [`GuruPlan::dft`] planner rejects the problem.
#[must_use]
pub fn fftw_plan_many_dft<T: Float>(
    rank: usize,
    ns: &[usize],
    howmany: usize,
    direction: Direction,
    flags: Flags,
) -> Option<GuruPlan<T>> {
    if rank != ns.len() || rank == 0 || howmany == 0 || ns.contains(&0) {
        return None;
    }

    // Row-major strides: the last dimension is unit-stride and every earlier
    // dimension steps over the product of all later extents. Walk from the
    // fastest dimension outward, accumulating the stride with overflow checks
    // (a plain `product()` would panic in debug and wrap in release).
    let mut transform_dims: Vec<IoDim> = Vec::with_capacity(rank);
    let mut stride: usize = 1;
    for &n in ns.iter().rev() {
        let step = isize::try_from(stride).ok()?;
        transform_dims.push(IoDim::new(n, step, step));
        stride = stride.checked_mul(n)?;
    }
    transform_dims.reverse();

    // After the loop `stride` is the element count of one transform, which is
    // also the distance between consecutive batch entries.
    let batch_stride = isize::try_from(stride).ok()?;

    // The whole batch must be addressable with `isize` offsets.
    if isize::try_from(stride.checked_mul(howmany)?).is_err() {
        return None;
    }

    let dims = Tensor::new(transform_dims);
    let howmany_dims = Tensor::new(vec![IoDim::new(howmany, batch_stride, batch_stride)]);
    GuruPlan::dft(&dims, &howmany_dims, direction, flags)
}
pub fn fftw_execute<T: Float>(plan: &Plan<T>, input: &[Complex<T>], output: &mut [Complex<T>]) {
plan.execute(input, output);
}
/// FFTW-style `destroy_plan`: consume and release a plan.
///
/// The plan is taken by value, so ownership ends here. The explicit `drop`
/// just makes that hand-off visible. Whatever resources the plan owns are
/// released by its normal `Drop` handling; no extra cleanup is performed.
pub fn fftw_destroy_plan<T: Float>(plan: Plan<T>) {
    drop(plan);
}
/// FFTW-style wisdom export: serialize the accumulated wisdom to a string.
///
/// This implementation always returns `Some`. The `Option` is presumably kept
/// for signature parity with FFTW, whose export can fail.
#[must_use]
pub fn fftw_export_wisdom_to_string() -> Option<String> {
    let serialized = export_to_string();
    Some(serialized)
}
/// FFTW-style wisdom import: load wisdom previously produced by an export.
///
/// Returns `true` when [`import_from_string`] succeeds and `false` otherwise.
/// The detailed error is intentionally dropped so the result mirrors FFTW's
/// flag-style return value.
pub fn fftw_import_wisdom_from_string(s: &str) -> bool {
    let outcome = import_from_string(s);
    outcome.is_ok()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::fft;

    /// Absolute tolerance for comparing `f64` transform outputs in these tests.
    const TOL: f64 = 1e-9;

    /// Asserts that `got` and `exp` agree within [`TOL`] in both the real and
    /// imaginary parts; `what` labels the comparison in the failure message.
    fn assert_complex_close(got: &Complex<f64>, exp: &Complex<f64>, what: &str) {
        assert!(
            (got.re - exp.re).abs() < TOL,
            "{what}: real part mismatch: got={} expected={}",
            got.re,
            exp.re
        );
        assert!(
            (got.im - exp.im).abs() < TOL,
            "{what}: imag part mismatch: got={} expected={}",
            got.im,
            exp.im
        );
    }

    /// Sum of `0, 1, ..., n - 1` as `f64`. This is the DC bin of an
    /// unnormalized forward DFT of the ramp input `x[i] = i` used below.
    fn ramp_sum(n: usize) -> f64 {
        (0..n).map(|i| i as f64).sum()
    }

    #[test]
    fn test_fftw_plan_dft_1d_some() {
        let plan = fftw_plan_dft_1d(256, Direction::Forward, Flags::ESTIMATE)
            .expect("1D plan of size 256 should be created");
        assert_eq!(plan.size(), 256);
    }

    #[test]
    fn test_fftw_plan_dft_1d_zero_nop() {
        // A zero-length 1D complex plan is accepted and reports size 0.
        let plan = fftw_plan_dft_1d(0, Direction::Forward, Flags::ESTIMATE)
            .expect("1D plan of size 0 should be created");
        assert_eq!(plan.size(), 0);
    }

    #[test]
    fn test_fftwf_plan_dft_1d_some() {
        let plan = fftwf_plan_dft_1d(128, Direction::Forward, Flags::ESTIMATE)
            .expect("f32 1D plan of size 128 should be created");
        assert_eq!(plan.size(), 128);
    }

    #[test]
    fn test_fftw_plan_dft_2d_some() {
        let plan = fftw_plan_dft_2d(8, 16, Direction::Forward, Flags::ESTIMATE)
            .expect("8x16 plan should be created");
        assert_eq!(plan.size(), 128);
    }

    #[test]
    fn test_fftw_plan_dft_2d_non_zero_some() {
        let plan = fftw_plan_dft_2d(8, 16, Direction::Forward, Flags::ESTIMATE);
        assert!(plan.is_some());
    }

    #[test]
    fn test_fftw_plan_dft_3d_some() {
        let plan = fftw_plan_dft_3d(4, 4, 4, Direction::Forward, Flags::ESTIMATE)
            .expect("4x4x4 plan should be created");
        assert_eq!(plan.size(), 64);
    }

    #[test]
    fn test_fftw_plan_dft_3d_non_zero_some() {
        let plan = fftw_plan_dft_3d(4, 4, 4, Direction::Forward, Flags::ESTIMATE);
        assert!(plan.is_some());
    }

    #[test]
    fn test_fftw_plan_dft_r2c_1d_some() {
        let plan = fftw_plan_dft_r2c_1d(64, Flags::ESTIMATE)
            .expect("R2C plan of size 64 should be created");
        assert_eq!(plan.size(), 64);
        // A length-n real transform has n / 2 + 1 distinct complex bins.
        assert_eq!(plan.complex_size(), 33);
    }

    #[test]
    fn test_fftw_plan_dft_r2c_1d_zero_none() {
        let plan = fftw_plan_dft_r2c_1d(0, Flags::ESTIMATE);
        assert!(plan.is_none());
    }

    #[test]
    fn test_fftw_plan_dft_c2r_1d_some() {
        let plan = fftw_plan_dft_c2r_1d(64, Flags::ESTIMATE)
            .expect("C2R plan of size 64 should be created");
        assert_eq!(plan.size(), 64);
    }

    #[test]
    fn test_fftw_plan_dft_c2r_1d_zero_none() {
        let plan = fftw_plan_dft_c2r_1d(0, Flags::ESTIMATE);
        assert!(plan.is_none());
    }

    #[test]
    fn test_fftw_execute_matches_fft_1d() {
        let n = 32;
        let input: Vec<Complex<f64>> = (0..n).map(|i| Complex::new(i as f64, 0.0)).collect();
        let reference = fft(&input);
        let plan = fftw_plan_dft_1d(n, Direction::Forward, Flags::ESTIMATE).unwrap();
        let mut output = vec![Complex::new(0.0, 0.0); n];
        fftw_execute(&plan, &input, &mut output);
        assert_eq!(output.len(), reference.len());
        for (got, exp) in output.iter().zip(reference.iter()) {
            assert_complex_close(got, exp, "fftw_execute vs fft");
        }
    }

    #[test]
    fn test_fftw_execute_backward() {
        let n = 16;
        let input: Vec<Complex<f64>> = (0..n)
            .map(|i| Complex::new((i as f64).cos(), (i as f64).sin()))
            .collect();
        let plan = fftw_plan_dft_1d(n, Direction::Backward, Flags::ESTIMATE).unwrap();
        let mut output = vec![Complex::new(0.0, 0.0); n];
        fftw_execute(&plan, &input, &mut output);
        let non_zero = output
            .iter()
            .any(|c| c.re.abs() > 1e-12 || c.im.abs() > 1e-12);
        assert!(non_zero, "backward FFT should produce non-zero output");
        // Bin 0 of a DFT is the plain sum of the inputs, whichever the sign.
        let expected_dc = input.iter().fold(Complex::new(0.0, 0.0), |acc, c| {
            Complex::new(acc.re + c.re, acc.im + c.im)
        });
        assert_complex_close(&output[0], &expected_dc, "backward DC bin");
    }

    #[test]
    fn test_fftw_execute_roundtrip() {
        let n = 8usize;
        let original: Vec<Complex<f64>> = (0..n).map(|i| Complex::new(i as f64, 0.0)).collect();
        let fwd = fftw_plan_dft_1d(n, Direction::Forward, Flags::ESTIMATE).unwrap();
        let bwd = fftw_plan_dft_1d(n, Direction::Backward, Flags::ESTIMATE).unwrap();
        let mut freq = vec![Complex::new(0.0, 0.0); n];
        fftw_execute(&fwd, &original, &mut freq);
        let mut recovered = vec![Complex::new(0.0, 0.0); n];
        fftw_execute(&bwd, &freq, &mut recovered);
        // The complex transforms are unnormalized, so forward+backward scales by n.
        let inv_n = 1.0 / n as f64;
        for (got, exp) in recovered.iter().zip(original.iter()) {
            let scaled = Complex::new(got.re * inv_n, got.im * inv_n);
            assert_complex_close(&scaled, exp, "round-trip");
        }
    }

    #[test]
    fn test_fftw_destroy_plan_does_not_panic() {
        let plan = fftw_plan_dft_1d(64, Direction::Forward, Flags::ESTIMATE).unwrap();
        fftw_destroy_plan(plan);
    }

    #[test]
    fn test_wisdom_roundtrip() {
        let wisdom_str =
            fftw_export_wisdom_to_string().expect("export should always return Some");
        assert!(
            wisdom_str.contains("oxifft-wisdom"),
            "exported string must contain version header"
        );
        let ok = fftw_import_wisdom_from_string(&wisdom_str);
        assert!(ok, "re-importing exported wisdom must succeed");
    }

    #[test]
    fn test_wisdom_import_bad_string_returns_false() {
        let ok = fftw_import_wisdom_from_string("not-valid-wisdom-at-all");
        assert!(!ok, "invalid wisdom string must return false");
    }

    #[test]
    fn test_fftw_plan_dft_2d_execute() {
        let (n0, n1) = (4usize, 4usize);
        let total = n0 * n1;
        let plan = fftw_plan_dft_2d(n0, n1, Direction::Forward, Flags::ESTIMATE).unwrap();
        let input: Vec<Complex<f64>> = (0..total).map(|i| Complex::new(i as f64, 0.0)).collect();
        let mut output = vec![Complex::new(0.0, 0.0); total];
        plan.execute(&input, &mut output);
        assert_complex_close(&output[0], &Complex::new(ramp_sum(total), 0.0), "2D DC bin");
    }

    #[test]
    fn test_fftw_plan_dft_3d_execute() {
        let (n0, n1, n2) = (2usize, 2usize, 2usize);
        let total = n0 * n1 * n2;
        let plan = fftw_plan_dft_3d(n0, n1, n2, Direction::Forward, Flags::ESTIMATE).unwrap();
        let input: Vec<Complex<f64>> = (0..total).map(|i| Complex::new(i as f64, 0.0)).collect();
        let mut output = vec![Complex::new(0.0, 0.0); total];
        plan.execute(&input, &mut output);
        assert_complex_close(&output[0], &Complex::new(ramp_sum(total), 0.0), "3D DC bin");
    }

    #[test]
    fn test_fftw_r2c_execute() {
        let n = 16usize;
        let input: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let mut output = vec![Complex::new(0.0f64, 0.0); n / 2 + 1];
        let plan = fftw_plan_dft_r2c_1d(n, Flags::ESTIMATE).unwrap();
        plan.execute_r2c(&input, &mut output);
        assert_complex_close(&output[0], &Complex::new(ramp_sum(n), 0.0), "R2C DC bin");
    }

    #[test]
    fn test_fftw_c2r_roundtrip() {
        let n = 16usize;
        let original: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let mut freq = vec![Complex::new(0.0f64, 0.0); n / 2 + 1];
        let r2c = fftw_plan_dft_r2c_1d(n, Flags::ESTIMATE).unwrap();
        r2c.execute_r2c(&original, &mut freq);
        let mut recovered = vec![0.0f64; n];
        let c2r = fftw_plan_dft_c2r_1d(n, Flags::ESTIMATE).unwrap();
        c2r.execute_c2r(&freq, &mut recovered);
        // No 1/n factor is applied here: this test pins that the R2C -> C2R
        // pair reproduces the original signal directly.
        for (got, exp) in recovered.iter().zip(original.iter()) {
            let diff = (got - exp).abs();
            assert!(
                diff < TOL,
                "C2R roundtrip mismatch: got={got} expected={exp}"
            );
        }
    }

    #[test]
    fn test_fftw_plan_many_dft_some() {
        let plan = fftw_plan_many_dft::<f64>(1, &[32], 4, Direction::Forward, Flags::ESTIMATE)
            .expect("plan_many_dft should succeed for valid args");
        assert_eq!(plan.batch_count(), 4);
        assert_eq!(plan.transform_size(), 32);
    }

    #[test]
    fn test_fftw_plan_many_dft_rank_mismatch_none() {
        let plan = fftw_plan_many_dft::<f64>(2, &[32], 4, Direction::Forward, Flags::ESTIMATE);
        assert!(plan.is_none(), "rank mismatch should return None");
    }

    #[test]
    fn test_fftw_plan_many_dft_rank_zero_none() {
        let plan = fftw_plan_many_dft::<f64>(0, &[], 4, Direction::Forward, Flags::ESTIMATE);
        assert!(plan.is_none(), "rank zero should return None");
    }

    #[test]
    fn test_fftw_plan_many_dft_zero_howmany_none() {
        let plan = fftw_plan_many_dft::<f64>(1, &[32], 0, Direction::Forward, Flags::ESTIMATE);
        assert!(plan.is_none(), "zero howmany should return None");
    }

    #[test]
    fn test_fftw_plan_many_dft_zero_dim_none() {
        let plan = fftw_plan_many_dft::<f64>(1, &[0], 4, Direction::Forward, Flags::ESTIMATE);
        assert!(plan.is_none(), "zero dimension should return None");
    }

    #[test]
    fn test_fftw_plan_many_dft_execute() {
        let n = 8usize;
        let howmany = 3usize;
        let total = n * howmany;
        let plan = fftw_plan_many_dft::<f64>(1, &[n], howmany, Direction::Forward, Flags::ESTIMATE)
            .unwrap();
        let input: Vec<Complex<f64>> = (0..total).map(|i| Complex::new(i as f64, 0.0)).collect();
        let mut output = vec![Complex::new(0.0, 0.0); total];
        plan.execute(&input, &mut output);
        // Batches are packed back to back (batch stride == n), so batch `b`
        // occupies elements `b * n .. (b + 1) * n` of both buffers, and its DC
        // bin must equal the sum of that batch's inputs.
        for (b, (inp, out)) in input.chunks(n).zip(output.chunks(n)).enumerate() {
            let expected_dc: f64 = inp.iter().map(|c| c.re).sum();
            assert_complex_close(
                &out[0],
                &Complex::new(expected_dc, 0.0),
                &format!("batch {b} DC bin"),
            );
        }
    }

    #[test]
    fn test_fftw_plan_many_dft_f32() {
        let plan = fftw_plan_many_dft::<f32>(1, &[16], 2, Direction::Forward, Flags::ESTIMATE)
            .expect("f32 plan_many_dft should succeed for valid args");
        assert_eq!(plan.batch_count(), 2);
    }
}