use crate::api::{Direction, Flags, Plan};
use crate::kernel::{Complex, Float};
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// Forward DFT of a length-`n` signal given only its nonzero samples.
///
/// `nonzero_inputs` lists `(index, value)` pairs. Every other sample is treated as zero, and
/// pairs with `index >= n` are ignored. Returns the `n` output bins (empty when `n == 0`).
///
/// Dispatches between a direct O(n·k) sparse DFT and a full FFT of the densified input.
/// `k` is the number of in-range entries. The direct form is used while
/// `k <= ceil(log2(n))`.
pub fn fft_pruned_input<T: Float>(
    nonzero_inputs: &[(usize, Complex<T>)],
    n: usize,
) -> Vec<Complex<T>> {
    if n == 0 {
        return Vec::new();
    }
    // Out-of-range entries contribute nothing, so they must not inflate the cost estimate.
    let k = nonzero_inputs
        .iter()
        .filter(|&&(idx, _)| idx < n)
        .count();
    if k == 0 {
        return vec![Complex::<T>::zero(); n];
    }
    // ceil(log2(n)), computed exactly in integer arithmetic (n >= 1 here).
    let crossover = (usize::BITS - (n - 1).leading_zeros()) as usize;
    if k <= crossover {
        dft_sparse_input(nonzero_inputs, n)
    } else {
        fft_full_from_sparse(nonzero_inputs, n)
    }
}
/// Direct O(n·k) forward DFT that visits only the listed nonzero samples.
///
/// Computes `X[j] = Σ x[m] · e^{-2πi·j·m/n}` over the `(m, x[m])` pairs in `nonzero_inputs`.
/// Pairs with `m >= n` are ignored, and duplicate indices are summed.
///
/// The twiddle index `(j·m) mod n` is advanced incrementally instead of forming the product
/// `j * m`. This avoids `usize` overflow for large `n` (notably on 32-bit targets) and keeps
/// the angle handed to `sin_cos` within `[0, 2π)`, which preserves precision.
fn dft_sparse_input<T: Float>(nonzero_inputs: &[(usize, Complex<T>)], n: usize) -> Vec<Complex<T>> {
    let mut output = vec![Complex::<T>::zero(); n];
    let two_pi = <T as Float>::PI + <T as Float>::PI;
    let n_t = T::from_usize(n);
    for &(m, value) in nonzero_inputs {
        if m >= n {
            continue;
        }
        // Invariant: `phase == (j * m) % n` for the bin `j` currently being updated.
        let mut phase = 0usize;
        for out in output.iter_mut() {
            let angle = two_pi * T::from_usize(phase) / n_t;
            let (sin_a, cos_a) = Float::sin_cos(angle);
            let twiddle = Complex::new(cos_a, T::ZERO - sin_a);
            *out = *out + value * twiddle;
            // phase < n and m < n, so one conditional subtraction keeps phase reduced.
            phase += m;
            if phase >= n {
                phase -= n;
            }
        }
    }
    output
}
fn fft_full_from_sparse<T: Float>(
nonzero_inputs: &[(usize, Complex<T>)],
n: usize,
) -> Vec<Complex<T>> {
let mut input = vec![Complex::<T>::zero(); n];
for &(idx, value) in nonzero_inputs {
if idx < n {
input[idx] = value;
}
}
let plan = match Plan::dft_1d(n, Direction::Forward, Flags::ESTIMATE) {
Some(p) => p,
None => return vec![Complex::<T>::zero(); n],
};
let mut output = vec![Complex::<T>::zero(); n];
plan.execute(&input, &mut output);
output
}
// Kept as an alternative to `fft_pruned_input`. It may have no in-crate callers, which
// would otherwise trigger `dead_code` if this module is not publicly reachable.
#[allow(dead_code)]
/// Radix-2 decimation-in-time forward FFT that skips butterflies whose inputs are known zero.
///
/// For power-of-two `n`, the nonzero samples are scattered into bit-reversed order. The
/// standard iterative butterfly stages then run, but a butterfly is skipped when both inputs
/// are still zero, and it reduces to a copy when only its top input is nonzero. Pairs with
/// `idx >= n` are ignored and duplicate indices are summed. Zero or non-power-of-two `n`
/// falls back to `fft_full_from_sparse`.
pub fn fft_pruned_input_butterfly<T: Float>(
    nonzero_inputs: &[(usize, Complex<T>)],
    n: usize,
) -> Vec<Complex<T>> {
    if n == 0 || !n.is_power_of_two() {
        return fft_full_from_sparse(nonzero_inputs, n);
    }
    let log_n = n.trailing_zeros() as usize;

    // Scatter the inputs into bit-reversed order, remembering which slots may be nonzero.
    let mut data = vec![Complex::<T>::zero(); n];
    let mut occupied = vec![false; n];
    for &(idx, value) in nonzero_inputs {
        if idx < n {
            let slot = bit_reverse(idx, log_n);
            data[slot] = data[slot] + value;
            occupied[slot] = true;
        }
    }

    // Every butterfly uses exp(-2πi·k/n) with k < n/2, so compute each twiddle once.
    let two_pi = <T as Float>::PI + <T as Float>::PI;
    let n_t = T::from_usize(n);
    let twiddles: Vec<Complex<T>> = (0..n / 2)
        .map(|k| {
            let angle = two_pi * T::from_usize(k) / n_t;
            let (sin_a, cos_a) = Float::sin_cos(angle);
            Complex::new(cos_a, T::ZERO - sin_a)
        })
        .collect();

    for stage in 0..log_n {
        let half_block = 1usize << stage;
        let block_size = half_block << 1;
        let stride = n / block_size;
        for block_start in (0..n).step_by(block_size) {
            for i in 0..half_block {
                let top = block_start + i;
                let bottom = top + half_block;
                match (occupied[top], occupied[bottom]) {
                    // Both inputs are zero, so both outputs stay zero.
                    (false, false) => {}
                    // b == 0, hence a + b == a - b == a: copy without multiplying.
                    (true, false) => {
                        data[bottom] = data[top];
                        occupied[bottom] = true;
                    }
                    _ => {
                        let a = data[top];
                        let b = data[bottom] * twiddles[i * stride];
                        data[top] = a + b;
                        data[bottom] = a - b;
                        occupied[top] = true;
                        occupied[bottom] = true;
                    }
                }
            }
        }
    }
    data
}
/// Reverses the order of the low `bits` bits of `x` (e.g. `bit_reverse(0b110, 3) == 0b011`).
///
/// Bits of `x` at or above position `bits` do not contribute; `bits == 0` yields 0.
fn bit_reverse(x: usize, bits: usize) -> usize {
    let (reversed, _) = (0..bits).fold((0usize, x), |(acc, rest), _| {
        // Move the lowest remaining bit of `rest` into the bottom of `acc`.
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dense reference spectrum: scatter `nonzero` into a zero buffer and run a planned FFT.
    fn reference_fft(nonzero: &[(usize, Complex<f64>)], n: usize) -> Vec<Complex<f64>> {
        let mut input = vec![Complex::new(0.0_f64, 0.0); n];
        for &(idx, val) in nonzero {
            input[idx] = val;
        }
        let plan = Plan::dft_1d(n, Direction::Forward, Flags::ESTIMATE).unwrap();
        let mut output = vec![Complex::new(0.0_f64, 0.0); n];
        plan.execute(&input, &mut output);
        output
    }

    /// Asserts element-wise agreement of two spectra within `tol` on both components.
    fn assert_close(actual: &[Complex<f64>], expected: &[Complex<f64>], tol: f64) {
        assert_eq!(actual.len(), expected.len());
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a.re - e.re).abs() < tol,
                "Real mismatch at index {i}: {} vs {}",
                a.re,
                e.re
            );
            assert!(
                (a.im - e.im).abs() < tol,
                "Imag mismatch at index {i}: {} vs {}",
                a.im,
                e.im
            );
        }
    }

    #[test]
    fn test_fft_pruned_input_empty() {
        let output: Vec<Complex<f64>> = fft_pruned_input(&[], 64);
        assert_eq!(output.len(), 64);
        assert!(output.iter().all(|c| c.re == 0.0 && c.im == 0.0));
    }

    #[test]
    fn test_fft_pruned_input_zero_length() {
        let output: Vec<Complex<f64>> = fft_pruned_input(&[(0, Complex::new(1.0_f64, 0.0))], 0);
        assert!(output.is_empty());
    }

    #[test]
    fn test_fft_pruned_input_single() {
        let n = 64;
        let nonzero = vec![(0, Complex::new(1.0_f64, 0.0))];
        let output = fft_pruned_input(&nonzero, n);
        assert_eq!(output.len(), n);
        for o in &output {
            assert!((o.re - 1.0).abs() < 1e-10);
            assert!(o.im.abs() < 1e-10);
        }
    }

    #[test]
    fn test_fft_pruned_input_vs_full() {
        let n = 64;
        let nonzero = vec![
            (0, Complex::new(1.0_f64, 0.0)),
            (10, Complex::new(0.5, 0.3)),
            (32, Complex::new(-1.0, 0.5)),
        ];
        let pruned_output = fft_pruned_input(&nonzero, n);
        assert_close(&pruned_output, &reference_fft(&nonzero, n), 1e-10);
    }

    /// Eight entries with n = 16 exceed ceil(log2 16) = 4, forcing the dense-FFT path.
    #[test]
    fn test_fft_pruned_input_dense_path_vs_full() {
        let n = 16;
        let nonzero: Vec<(usize, Complex<f64>)> = (0..8)
            .map(|i| (2 * i, Complex::new(i as f64 + 1.0, 0.5 - i as f64)))
            .collect();
        let output = fft_pruned_input(&nonzero, n);
        assert_close(&output, &reference_fft(&nonzero, n), 1e-10);
    }

    #[test]
    fn test_fft_pruned_input_ignores_out_of_range() {
        let n = 8;
        let nonzero = vec![
            (3, Complex::new(1.0_f64, -1.0)),
            (8, Complex::new(5.0, 5.0)),
            (100, Complex::new(2.0, 0.0)),
        ];
        let output = fft_pruned_input(&nonzero, n);
        assert_close(&output, &reference_fft(&nonzero[..1], n), 1e-10);
    }

    #[test]
    fn test_fft_pruned_input_butterfly() {
        let n = 64;
        let nonzero = vec![(0, Complex::new(1.0_f64, 0.0)), (5, Complex::new(0.5, 0.3))];
        let pruned_output = fft_pruned_input_butterfly(&nonzero, n);
        assert_close(&pruned_output, &reference_fft(&nonzero, n), 1e-8);
    }

    #[test]
    fn test_fft_pruned_input_butterfly_dense() {
        let n = 16;
        let nonzero: Vec<(usize, Complex<f64>)> = (0..n)
            .map(|i| (i, Complex::new(1.0 + i as f64, (i % 3) as f64)))
            .collect();
        let output = fft_pruned_input_butterfly(&nonzero, n);
        assert_close(&output, &reference_fft(&nonzero, n), 1e-8);
    }

    #[test]
    fn test_dft_sparse_input() {
        let n = 32;
        let nonzero = vec![(0, Complex::new(1.0_f64, 0.0)), (1, Complex::new(0.5, 0.5))];
        let output = dft_sparse_input(&nonzero, n);
        assert_eq!(output.len(), n);
        assert_close(&output, &reference_fft(&nonzero, n), 1e-10);
    }

    #[test]
    fn test_bit_reverse() {
        assert_eq!(bit_reverse(0, 0), 0);
        assert_eq!(bit_reverse(1, 3), 4);
        assert_eq!(bit_reverse(6, 3), 3);
        assert_eq!(bit_reverse(0b1011, 4), 0b1101);
    }
}