use std::f64::consts::PI;
/// The four complex outputs returned by [`butterfly_4`], each stored as a `(re, im)` pair.
pub type Butterfly4Result = ((f64, f64), (f64, f64), (f64, f64), (f64, f64));
/// Complex addition on `(re, im)` pairs: returns `a + b`.
#[inline(always)]
fn cadd(a: (f64, f64), b: (f64, f64)) -> (f64, f64) {
    let (a_re, a_im) = a;
    let (b_re, b_im) = b;
    (a_re + b_re, a_im + b_im)
}
/// Complex subtraction on `(re, im)` pairs: returns `a - b`.
#[inline(always)]
fn csub(a: (f64, f64), b: (f64, f64)) -> (f64, f64) {
    let (a_re, a_im) = a;
    let (b_re, b_im) = b;
    (a_re - b_re, a_im - b_im)
}
/// Complex multiplication on `(re, im)` pairs: returns `a * b`.
#[inline(always)]
fn cmul(a: (f64, f64), b: (f64, f64)) -> (f64, f64) {
    let (a_re, a_im) = a;
    let (b_re, b_im) = b;
    let re = a_re * b_re - a_im * b_im;
    let im = a_re * b_im + a_im * b_re;
    (re, im)
}
/// Multiplies both components of the complex value `a` by the real factor `s`.
#[inline(always)]
fn cscale(a: (f64, f64), s: f64) -> (f64, f64) {
    let (re, im) = a;
    (re * s, im * s)
}
/// Multiplies the complex value `a` by `-i`, i.e. maps `(re, im)` to `(im, -re)`.
#[inline(always)]
fn cmul_neg_i(a: (f64, f64)) -> (f64, f64) {
    let (re, im) = a;
    (im, -re)
}
/// Radix-2 butterfly.
///
/// Multiplies `b` by the twiddle factor `w` and returns `(a + w*b, a - w*b)`.
#[inline(always)]
pub fn butterfly_2(a: (f64, f64), b: (f64, f64), w: (f64, f64)) -> ((f64, f64), (f64, f64)) {
    let twiddled = cmul(w, b);
    let upper = cadd(a, twiddled);
    let lower = csub(a, twiddled);
    (upper, lower)
}
/// Radix-4 butterfly.
///
/// Applies the twiddle factors `w1`, `w2`, `w3` to `b`, `c`, `d` respectively and
/// combines the results with `a` into four outputs:
///
/// * `a + w1*b + w2*c + w3*d`
/// * `(a - w2*c) - i*(w1*b - w3*d)`
/// * `(a + w2*c) - (w1*b + w3*d)`
/// * `(a - w2*c) + i*(w1*b - w3*d)`
// The single-letter parameter names follow the usual butterfly notation.
#[allow(clippy::many_single_char_names)]
#[inline(always)]
pub fn butterfly_4(
    a: (f64, f64),
    b: (f64, f64),
    c: (f64, f64),
    d: (f64, f64),
    w1: (f64, f64),
    w2: (f64, f64),
    w3: (f64, f64),
) -> Butterfly4Result {
    let b_tw = cmul(w1, b);
    let c_tw = cmul(w2, c);
    let d_tw = cmul(w3, d);

    let even_sum = cadd(a, c_tw);
    let even_diff = csub(a, c_tw);
    let odd_sum = cadd(b_tw, d_tw);
    // The odd difference is rotated by -i before being combined.
    let odd_diff_rot = cmul_neg_i(csub(b_tw, d_tw));

    (
        cadd(even_sum, odd_sum),
        cadd(even_diff, odd_diff_rot),
        csub(even_sum, odd_sum),
        csub(even_diff, odd_diff_rot),
    )
}
/// Splits `n` into the radices used by the transforms in this file.
///
/// Factors of 4 are extracted first, then at most one factor of 2, then every
/// factor of 3, 5 and 7. Whatever remains (if greater than 1) is appended as a
/// single trailing entry; that entry is not necessarily prime (e.g. `121`).
/// The product of the returned factors equals `n` for every `n >= 1`.
///
/// Returns an empty vector for `n == 0` and `n == 1`.
pub fn factorize(mut n: usize) -> Vec<usize> {
    let mut factors = Vec::new();
    // `0 % 4 == 0` holds forever, so zero must be rejected before the loops run.
    if n == 0 {
        return factors;
    }
    while n % 4 == 0 {
        factors.push(4);
        n /= 4;
    }
    // After removing all 4s at most a single factor of 2 can remain.
    if n % 2 == 0 {
        factors.push(2);
        n /= 2;
    }
    for &p in &[3usize, 5, 7] {
        while n % p == 0 {
            factors.push(p);
            n /= p;
        }
    }
    if n > 1 {
        factors.push(n);
    }
    factors
}
/// Builds the forward twiddle table `[exp(-2*pi*i*k/n) for k in 0..n]` as `(re, im)` pairs.
///
/// Every entry is evaluated directly from its own angle rather than by repeatedly
/// multiplying by the first root, so rounding error does not grow with `k`.
///
/// Returns an empty vector when `n == 0`.
pub fn compute_twiddles(n: usize) -> Vec<(f64, f64)> {
    let len = n as f64;
    (0..n)
        .map(|k| {
            let theta = -2.0 * PI * k as f64 / len;
            let (sin_t, cos_t) = theta.sin_cos();
            (cos_t, sin_t)
        })
        .collect()
}
/// Builds the inverse twiddle table `[exp(+2*pi*i*k/n) for k in 0..n]` as `(re, im)` pairs.
///
/// Every entry is evaluated directly from its own angle rather than by repeatedly
/// multiplying by the first root, so rounding error does not grow with `k`.
///
/// Returns an empty vector when `n == 0`.
pub fn compute_twiddles_inv(n: usize) -> Vec<(f64, f64)> {
    let len = n as f64;
    (0..n)
        .map(|k| {
            let theta = 2.0 * PI * k as f64 / len;
            let (sin_t, cos_t) = theta.sin_cos();
            (cos_t, sin_t)
        })
        .collect()
}
/// Forward FFT for power-of-two lengths (iterative radix-2, decimation in time).
///
/// The input is copied, permuted into bit-reversed order, and then combined in
/// stages of doubling block length. The result is unnormalised.
///
/// Returns an empty vector for empty input.
///
/// # Panics
///
/// Panics if `data` is non-empty and its length is not a power of two.
pub fn fft_1d_pow2(data: &[(f64, f64)]) -> Vec<(f64, f64)> {
    let n = data.len();
    // Must precede the power-of-two check: `0usize.is_power_of_two()` is false.
    if n == 0 {
        return Vec::new();
    }
    assert!(
        n.is_power_of_two(),
        "fft_1d_pow2 requires power-of-2 length"
    );

    let mut output = data.to_vec();

    // Bit-reversal permutation; `rev > i` ensures each pair is swapped exactly once.
    let bits = n.trailing_zeros() as usize;
    for i in 0..n {
        let rev = bit_reverse(i, bits);
        if rev > i {
            output.swap(i, rev);
        }
    }

    let mut len = 2usize;
    while len <= n {
        let half = len / 2;
        let theta = -PI / half as f64;
        let w_step = (theta.cos(), theta.sin());
        // `n` is a multiple of `len`, so the blocks tile the buffer exactly.
        for block in output.chunks_exact_mut(len) {
            let (lower, upper) = block.split_at_mut(half);
            let mut w = (1.0f64, 0.0f64);
            for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
                let u = *lo;
                let v = cmul(w, *hi);
                *lo = cadd(u, v);
                *hi = csub(u, v);
                w = cmul(w, w_step);
            }
        }
        len <<= 1;
    }
    output
}
/// Unnormalised inverse FFT for power-of-two lengths.
///
/// Conjugates the input, runs [`fft_1d_pow2`] on it, and conjugates the result.
/// No `1/n` scaling is applied.
pub fn ifft_1d_pow2_raw(data: &[(f64, f64)]) -> Vec<(f64, f64)> {
    let conjugated: Vec<(f64, f64)> = data.iter().map(|z| (z.0, -z.1)).collect();
    let mut spectrum = fft_1d_pow2(&conjugated);
    for z in spectrum.iter_mut() {
        z.1 = -z.1;
    }
    spectrum
}
/// Reverses the lowest `bits` bits of `x`; any higher bits of `x` are ignored.
///
/// Uses the standard-library whole-word reversal and shifts the result into
/// place instead of looping over individual bits.
#[inline(always)]
fn bit_reverse(x: usize, bits: usize) -> usize {
    const WIDTH: usize = usize::BITS as usize;
    let reversed = x.reverse_bits();
    if bits == 0 {
        0
    } else if bits <= WIDTH {
        // The low `bits` bits of `x` now sit at the top of the word.
        reversed >> (WIDTH - bits)
    } else if bits - WIDTH < WIDTH {
        // Wider than a machine word: the reversed word is pushed further left and
        // the bits that fall off the top are discarded.
        reversed << (bits - WIDTH)
    } else {
        0
    }
}
/// Returns `true` when [`factorize`] reports at least one factor of `n` greater than 7.
fn needs_bluestein(n: usize) -> bool {
    factorize(n).into_iter().any(|factor| factor > 7)
}
/// Smallest power of two that is `>= n`; `0` and `1` both map to `1`.
#[inline]
fn next_pow2(n: usize) -> usize {
    match n {
        0 | 1 => 1,
        _ => n.next_power_of_two(),
    }
}
/// Forward DFT of arbitrary length via Bluestein's chirp-z algorithm.
///
/// The transform is rewritten as a convolution with the chirp
/// `exp(-i*pi*k^2/n)`, which is evaluated with power-of-two FFTs of length
/// `m = next_pow2(2n - 1)`. The result is unnormalised.
///
/// Inputs of length 0 or 1 are returned unchanged.
pub fn bluestein_fft(input: &[(f64, f64)]) -> Vec<(f64, f64)> {
    let n = input.len();
    if n <= 1 {
        return input.to_vec();
    }

    // chirp[k] = exp(-i*pi*k^2/n). The phase is periodic in k^2 with period 2n, so
    // k^2 is reduced modulo 2n in integer arithmetic first. That keeps the angle
    // small (full sin/cos precision) and the u128 product cannot overflow, unlike
    // `k * k` in `usize` on 32-bit targets.
    let period = 2 * n as u128;
    let chirp: Vec<(f64, f64)> = (0..n)
        .map(|k| {
            let k = k as u128;
            let reduced = (k * k % period) as f64;
            let theta = -PI * reduced / n as f64;
            let (s, c) = theta.sin_cos();
            (c, s)
        })
        .collect();

    let m = next_pow2(2 * n - 1);

    // a[k] = x[k] * chirp[k], zero-padded to length m.
    let mut a_padded = vec![(0.0f64, 0.0f64); m];
    for ((slot, &x), &w) in a_padded.iter_mut().zip(input).zip(&chirp) {
        *slot = cmul(x, w);
    }

    // b[k] = conj(chirp[k]), mirrored to the negative indices (stored at m - k) so
    // the cyclic convolution of length m reproduces the linear one.
    let mut b_padded = vec![(0.0f64, 0.0f64); m];
    b_padded[0] = (chirp[0].0, -chirp[0].1);
    for k in 1..n {
        let conj = (chirp[k].0, -chirp[k].1);
        b_padded[k] = conj;
        b_padded[m - k] = conj;
    }

    let a_fft = fft_1d_pow2(&a_padded);
    let b_fft = fft_1d_pow2(&b_padded);
    let prod_fft: Vec<(f64, f64)> = a_fft
        .iter()
        .zip(b_fft.iter())
        .map(|(&a, &b)| cmul(a, b))
        .collect();
    let conv_raw = ifft_1d_pow2_raw(&prod_fft);

    // Only the first n convolution outputs are needed; normalise the raw inverse
    // transform by 1/m and undo the chirp modulation in the same pass.
    let scale = 1.0 / m as f64;
    conv_raw
        .iter()
        .zip(&chirp)
        .map(|(&y, &w)| cmul(cscale(y, scale), w))
        .collect()
}
/// Unnormalised inverse DFT of arbitrary length built on [`bluestein_fft`].
///
/// Conjugates the input, applies the forward transform, and conjugates the
/// result. No `1/n` scaling is applied.
pub fn bluestein_ifft_raw(input: &[(f64, f64)]) -> Vec<(f64, f64)> {
    let conjugated: Vec<(f64, f64)> = input.iter().map(|z| (z.0, -z.1)).collect();
    let mut result = bluestein_fft(&conjugated);
    for z in result.iter_mut() {
        z.1 = -z.1;
    }
    result
}
/// Recursive decimation-in-time DFT for composite lengths.
///
/// `sign` selects the exponent sign: negative gives the forward transform,
/// anything else the unnormalised inverse. Lengths 0 and 1 are returned as-is and
/// power-of-two lengths are handed to the dedicated radix-2 routines. Otherwise
/// the input is split into `radix` interleaved sub-sequences (with `radix` being
/// the first entry of `factorize(n)`), each is transformed recursively, and the
/// partial spectra are recombined with twiddle factors `exp(sign*2*pi*i*j*q/n)`.
fn mixed_radix_rec(data: &[(f64, f64)], sign: f64) -> Vec<(f64, f64)> {
    let n = data.len();
    if n <= 1 {
        return data.to_vec();
    }
    if n.is_power_of_two() {
        return if sign < 0.0 {
            fft_1d_pow2(data)
        } else {
            ifft_1d_pow2_raw(data)
        };
    }

    let radix = factorize(n)[0];
    let sub_len = n / radix;

    // Branch `offset` holds the transform of data[offset], data[offset + radix], ...
    let branches: Vec<Vec<(f64, f64)>> = (0..radix)
        .map(|offset| {
            let decimated: Vec<(f64, f64)> = data
                .iter()
                .skip(offset)
                .step_by(radix)
                .copied()
                .collect();
            mixed_radix_rec(&decimated, sign)
        })
        .collect();

    let mut out = vec![(0.0f64, 0.0f64); n];
    for (out_idx, slot) in out.iter_mut().enumerate() {
        // The sub-spectra are periodic with period `sub_len`.
        let k = out_idx % sub_len;
        *slot = branches
            .iter()
            .enumerate()
            .fold((0.0f64, 0.0f64), |sum, (j, branch)| {
                let exp = j * out_idx;
                let theta = sign * 2.0 * PI * exp as f64 / n as f64;
                let (sin_t, cos_t) = theta.sin_cos();
                cadd(sum, cmul((cos_t, sin_t), branch[k]))
            });
    }
    out
}
/// Direct DFT of a short sequence, with hand-written kernels for lengths 1–4.
///
/// `sign` selects the exponent sign: negative gives the forward transform
/// (`exp(-2*pi*i*j*k/r)`), anything else the unnormalised inverse. Lengths
/// other than 1–4 (including 0) use the generic O(r^2) summation.
fn small_dft(scratch: &[(f64, f64)], sign: f64) -> Vec<(f64, f64)> {
    match scratch.len() {
        1 => vec![scratch[0]],
        2 => {
            let a = scratch[0];
            let b = scratch[1];
            vec![cadd(a, b), csub(a, b)]
        }
        3 => {
            // w is the primitive cube root of unity for the requested direction.
            let (s, c) = (sign * 2.0 * PI / 3.0).sin_cos();
            let w = (c, s);
            let w2 = cmul(w, w);
            let x0 = scratch[0];
            let x1 = scratch[1];
            let x2 = scratch[2];
            vec![
                cadd(x0, cadd(x1, x2)),
                cadd(x0, cadd(cmul(w, x1), cmul(w2, x2))),
                // Output 2 needs w^2 * x1 + w^4 * x2, and w^4 == w because w^3 == 1.
                cadd(x0, cadd(cmul(w2, x1), cmul(w, x2))),
            ]
        }
        4 => {
            let x0 = scratch[0];
            let x1 = scratch[1];
            let x2 = scratch[2];
            let x3 = scratch[3];
            let t0 = cadd(x0, x2);
            let t1 = csub(x0, x2);
            let t2 = cadd(x1, x3);
            let t3_raw = csub(x1, x3);
            // Forward transform rotates the odd difference by -i, inverse by +i.
            let t3 = if sign < 0.0 {
                cmul_neg_i(t3_raw)
            } else {
                (-t3_raw.1, t3_raw.0)
            };
            vec![cadd(t0, t2), cadd(t1, t3), csub(t0, t2), csub(t1, t3)]
        }
        r => (0..r)
            .map(|k| {
                scratch
                    .iter()
                    .enumerate()
                    .fold((0.0f64, 0.0f64), |acc, (j, &x)| {
                        let theta = sign * 2.0 * PI * (k * j) as f64 / r as f64;
                        let (s, c) = theta.sin_cos();
                        cadd(acc, cmul(x, (c, s)))
                    })
            })
            .collect(),
    }
}
/// Forward (unnormalised) DFT of a sequence of any length.
///
/// Dispatch order: empty and single-element inputs are returned directly;
/// lengths for which [`needs_bluestein`] is true go to [`bluestein_fft`];
/// power-of-two lengths go to [`fft_1d_pow2`]; everything else is handled by
/// the recursive mixed-radix routine with a negative exponent sign.
pub fn fft_1d(input: &[(f64, f64)]) -> Vec<(f64, f64)> {
    match input.len() {
        0 => Vec::new(),
        1 => vec![input[0]],
        n if needs_bluestein(n) => bluestein_fft(input),
        n if n.is_power_of_two() => fft_1d_pow2(input),
        _ => mixed_radix_rec(input, -1.0),
    }
}
/// Inverse DFT of a sequence of any length, without the `1/n` normalisation.
///
/// Dispatch order: empty and single-element inputs are returned directly;
/// lengths for which [`needs_bluestein`] is true go to [`bluestein_ifft_raw`];
/// power-of-two lengths go to [`ifft_1d_pow2_raw`]; everything else is handled
/// by the recursive mixed-radix routine with a positive exponent sign.
pub fn ifft_1d_raw(input: &[(f64, f64)]) -> Vec<(f64, f64)> {
    match input.len() {
        0 => Vec::new(),
        1 => vec![input[0]],
        n if needs_bluestein(n) => bluestein_ifft_raw(input),
        n if n.is_power_of_two() => ifft_1d_pow2_raw(input),
        _ => mixed_radix_rec(input, 1.0),
    }
}
/// Normalised inverse DFT: the output of [`ifft_1d_raw`] with every element
/// multiplied by `1 / input.len()`.
///
/// Empty input yields an empty vector (the scale factor is never applied).
pub fn ifft_1d(input: &[(f64, f64)]) -> Vec<(f64, f64)> {
    let scale = 1.0 / input.len() as f64;
    let mut samples = ifft_1d_raw(input);
    for z in samples.iter_mut() {
        *z = cscale(*z, scale);
    }
    samples
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Reference O(n^2) forward DFT used as the expected answer.
    fn naive_dft(input: &[(f64, f64)]) -> Vec<(f64, f64)> {
        let n = input.len();
        let mut spectrum = Vec::with_capacity(n);
        for k in 0..n {
            let mut acc = (0.0f64, 0.0f64);
            for (j, &x) in input.iter().enumerate() {
                let theta = -2.0 * PI * k as f64 * j as f64 / n as f64;
                let (s, c) = theta.sin_cos();
                acc = cadd(acc, cmul(x, (c, s)));
            }
            spectrum.push(acc);
        }
        spectrum
    }

    /// Compares two complex sequences pairwise (over their common prefix) with the
    /// given absolute tolerance on both components.
    fn assert_all_close(actual: &[(f64, f64)], expected: &[(f64, f64)], eps: f64) {
        for (a, b) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(a.0, b.0, epsilon = eps);
            assert_relative_eq!(a.1, b.1, epsilon = eps);
        }
    }

    /// Real-valued ramp `0, 1, ..., len - 1`.
    fn ramp(len: usize) -> Vec<(f64, f64)> {
        (0..len).map(|i| (i as f64, 0.0)).collect()
    }

    #[test]
    fn test_radix2_power_of_2() {
        let input: Vec<(f64, f64)> = vec![
            (1.0, 0.0),
            (2.0, 0.0),
            (3.0, 0.0),
            (4.0, 0.0),
            (5.0, 0.0),
            (6.0, 0.0),
            (7.0, 0.0),
            (8.0, 0.0),
        ];
        let fft_out = fft_1d(&input);
        let ref_out = naive_dft(&input);
        assert_eq!(fft_out.len(), 8);
        assert_all_close(&fft_out, &ref_out, 1e-9);
    }

    #[test]
    fn test_radix4_size_16() {
        let input = ramp(16);
        let fft_out = fft_1d(&input);
        let ref_out = naive_dft(&input);
        assert_eq!(fft_out.len(), 16);
        assert_all_close(&fft_out, &ref_out, 1e-8);
    }

    #[test]
    fn test_mixed_radix_size_12() {
        let input = ramp(12);
        let fft_out = fft_1d(&input);
        let ref_out = naive_dft(&input);
        assert_all_close(&fft_out, &ref_out, 1e-8);
    }

    #[test]
    fn test_mixed_radix_size_30() {
        let input = ramp(30);
        let fft_out = fft_1d(&input);
        let ref_out = naive_dft(&input);
        assert_all_close(&fft_out, &ref_out, 1e-7);
    }

    #[test]
    fn test_bluestein_prime_size_7() {
        let input: Vec<(f64, f64)> = (0..7).map(|i| (i as f64 + 1.0, 0.0)).collect();
        let fft_out = bluestein_fft(&input);
        let ref_out = naive_dft(&input);
        assert_eq!(fft_out.len(), 7);
        assert_all_close(&fft_out, &ref_out, 1e-8);
    }

    #[test]
    fn test_bluestein_prime_size_11() {
        let input: Vec<(f64, f64)> = (0..11).map(|i| ((i + 1) as f64, 0.0)).collect();
        let fft_out = bluestein_fft(&input);
        let ref_out = naive_dft(&input);
        assert_eq!(fft_out.len(), 11);
        assert_all_close(&fft_out, &ref_out, 1e-7);
    }

    #[test]
    fn test_fft_1d_roundtrip() {
        let input: Vec<(f64, f64)> = (0..32).map(|i| (i as f64 * 0.1, 0.0)).collect();
        let freq = fft_1d(&input);
        let recovered = ifft_1d(&freq);
        assert_all_close(&input, &recovered, 1e-10);
    }

    #[test]
    fn test_twiddle_precomputation() {
        let n = 8;
        let twiddles = compute_twiddles(n);
        for (k, &(re, im)) in twiddles.iter().enumerate() {
            let theta = -2.0 * PI * k as f64 / n as f64;
            assert_relative_eq!(re, theta.cos(), epsilon = 1e-14);
            assert_relative_eq!(im, theta.sin(), epsilon = 1e-14);
        }
    }

    #[test]
    fn test_factorize_powers_of_2() {
        let f = factorize(8);
        assert_eq!(f.iter().product::<usize>(), 8);
        assert!(f.iter().all(|&x| matches!(x, 2 | 3 | 4 | 5 | 7)));
    }

    #[test]
    fn test_factorize_composite() {
        let f = factorize(12);
        assert_eq!(f.iter().product::<usize>(), 12);
        assert!(f.contains(&3));
    }
}