wgsl-fft 0.1.0

GPU-accelerated FFT using Webgpu compute shaders
Documentation
use std::any::Any;

use crate::FftExecutor;
use cudarc::cufft::{sys, CudaFft, FftDirection};
use cudarc::driver::{CudaContext, CudaStream};
use num_complex::Complex;
use std::error::Error;
use std::io::{Error as IoError, ErrorKind};
use std::sync::Arc;

/// Wrapper for cuFFT (via cudarc) to compare performance with our WebGPU implementation.
pub struct CuFft {
    plan: CudaFft,
    stream: Arc<CudaStream>,
    fft_size: usize,
}

impl FftExecutor for CuFft {
    fn name(&self) -> &str {
        "cuFFT (NVIDIA Gold Standard)"
    }

    fn fft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>, Box<dyn Error>> {
        self.batch_fft(inputs)
    }

    fn ifft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>, Box<dyn Error>> {
        self.batch_ifft(inputs)
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

// Note: cuFFT uses a completely different architecture (CUDA vs wgpu)
// so it needs a specialized GPU-only benchmark implementation
// For now, cuFFT will use the full benchmarking path

impl CuFft {
    /// Create a new cuFFT instance.
    pub fn new(fft_size: usize) -> Result<Self, Box<dyn Error>> {
        if fft_size == 0 {
            return Err(IoError::new(ErrorKind::InvalidInput, "fft_size must be > 0").into());
        }

        let ctx = CudaContext::new(0)?;
        let stream = ctx.default_stream();
        let single_plan = CudaFft::plan_1d(
            fft_size as i32,
            sys::cufftType::CUFFT_C2C,
            1,
            stream.clone(),
        )?;

        Ok(Self {
            plan: single_plan,
            stream,
            fft_size,
        })
    }

    fn ensure_input_size(&self, input: &[Complex<f32>]) -> Result<(), Box<dyn Error>> {
        if input.len() != self.fft_size {
            return Err(IoError::new(
                ErrorKind::InvalidInput,
                format!(
                    "input length {} does not match planned fft_size {}",
                    input.len(),
                    self.fft_size
                ),
            )
            .into());
        }
        Ok(())
    }

    fn validate_batch_inputs(inputs: &[Vec<Complex<f32>>]) -> Result<usize, Box<dyn Error>> {
        if inputs.is_empty() {
            return Ok(0);
        }

        let n = inputs[0].len();
        if n == 0 {
            return Err(IoError::new(
                ErrorKind::InvalidInput,
                "batch input vectors must be non-empty",
            )
            .into());
        }

        for (idx, input) in inputs.iter().enumerate() {
            if input.len() != n {
                return Err(IoError::new(
                    ErrorKind::InvalidInput,
                    format!(
                        "all batch inputs must have the same length; input[0]={n}, input[{idx}]={}",
                        input.len()
                    ),
                )
                .into());
            }
        }

        Ok(n)
    }

    fn to_cuda_complex(values: &[Complex<f32>]) -> Vec<sys::float2> {
        values
            .iter()
            .map(|c| sys::float2 { x: c.re, y: c.im })
            .collect()
    }

    fn from_cuda_complex(values: &[sys::float2]) -> Vec<Complex<f32>> {
        values.iter().map(|c| Complex::new(c.x, c.y)).collect()
    }

    /// Convert from Complex<f32> vectors to flattened sys::float2 format
    pub fn to_flattened_input(inputs: &[Vec<Complex<f32>>]) -> Vec<sys::float2> {
        inputs
            .iter()
            .flat_map(|v| v.iter().map(|c| sys::float2 { x: c.re, y: c.im }))
            .collect()
    }

    /// Convert from flattened sys::float2 format to Complex<f32> vectors
    pub fn from_flattened_output(
        flat_output: Vec<sys::float2>,
        n: usize,
    ) -> Vec<Vec<Complex<f32>>> {
        flat_output
            .chunks(n)
            .map(|chunk| chunk.iter().map(|c| Complex::new(c.x, c.y)).collect())
            .collect()
    }

    fn scale_inverse(output: &mut [Complex<f32>], n: usize) {
        let scale = 1.0 / n as f32;
        for c in output {
            c.re *= scale;
            c.im *= scale;
        }
    }

    /// Perform forward FFT.
    pub fn fft(&self, input: &[Complex<f32>]) -> Result<Vec<Complex<f32>>, Box<dyn Error>> {
        self.ensure_input_size(input)?;

        let host_input = Self::to_cuda_complex(input);
        let mut d_input = self.stream.clone_htod(&host_input)?;
        let mut d_output = self.stream.alloc_zeros::<sys::float2>(self.fft_size)?;

        self.plan
            .exec_c2c(&mut d_input, &mut d_output, FftDirection::Forward)?;

        let host_output: Vec<sys::float2> = self.stream.clone_dtoh(&d_output)?;
        Ok(Self::from_cuda_complex(&host_output))
    }

    /// Perform inverse FFT.
    pub fn ifft(&self, input: &[Complex<f32>]) -> Result<Vec<Complex<f32>>, Box<dyn Error>> {
        self.ensure_input_size(input)?;

        let host_input = Self::to_cuda_complex(input);
        let mut d_input = self.stream.clone_htod(&host_input)?;
        let mut d_output = self.stream.alloc_zeros::<sys::float2>(self.fft_size)?;

        self.plan
            .exec_c2c(&mut d_input, &mut d_output, FftDirection::Inverse)?;

        let host_output: Vec<sys::float2> = self.stream.clone_dtoh(&d_output)?;
        let mut result = Self::from_cuda_complex(&host_output);
        Self::scale_inverse(&mut result, self.fft_size);
        Ok(result)
    }

    /// Perform batch FFT on pre-flattened input (optimized for benchmarking).
    /// Input: flattened vector of complex pairs [re0, im0, re1, im1, ...]
    /// Output: flattened vector of complex pairs
    pub fn batch_fft_flattened(
        &self,
        flat_input: &[sys::float2],
        n: usize,
        batch_size: usize,
    ) -> Result<Vec<sys::float2>, Box<dyn Error>> {
        if flat_input.is_empty() {
            return Ok(Vec::new());
        }

        let mut d_input = self.stream.clone_htod(flat_input)?;
        let mut d_output = self.stream.alloc_zeros::<sys::float2>(n * batch_size)?;

        self.plan
            .exec_c2c(&mut d_input, &mut d_output, FftDirection::Forward)?;

        let flat_output: Vec<sys::float2> = self.stream.clone_dtoh(&d_output)?;
        Ok(flat_output)
    }

    /// Perform batch FFT.
    pub fn batch_fft(
        &self,
        inputs: &[Vec<Complex<f32>>],
    ) -> Result<Vec<Vec<Complex<f32>>>, Box<dyn Error>> {
        if inputs.is_empty() {
            return Ok(Vec::new());
        }

        let n = Self::validate_batch_inputs(inputs)?;
        let batch_size = inputs.len();

        let flat_input: Vec<sys::float2> = inputs
            .iter()
            .flat_map(|v| v.iter().map(|c| sys::float2 { x: c.re, y: c.im }))
            .collect();

        let flat_output = self.batch_fft_flattened(&flat_input, n, batch_size)?;
        let result = flat_output.chunks(n).map(Self::from_cuda_complex).collect();

        Ok(result)
    }

    /// Perform batch IFFT on pre-flattened input (optimized for benchmarking).
    /// Input: flattened vector of complex pairs [re0, im0, re1, im1, ...]
    /// Output: flattened vector of complex pairs (scaled by 1/N)
    pub fn batch_ifft_flattened(
        &self,
        flat_input: &[sys::float2],
        n: usize,
        batch_size: usize,
    ) -> Result<Vec<sys::float2>, Box<dyn Error>> {
        if flat_input.is_empty() {
            return Ok(Vec::new());
        }

        let mut d_input = self.stream.clone_htod(flat_input)?;
        let mut d_output = self.stream.alloc_zeros::<sys::float2>(n * batch_size)?;

        self.plan
            .exec_c2c(&mut d_input, &mut d_output, FftDirection::Inverse)?;

        let mut flat_output: Vec<sys::float2> = self.stream.clone_dtoh(&d_output)?;

        // Apply scaling for inverse transform
        let scale = 1.0 / n as f32;
        for c in &mut flat_output {
            c.x *= scale;
            c.y *= scale;
        }

        Ok(flat_output)
    }

    fn batch_ifft(
        &self,
        inputs: &[Vec<Complex<f32>>],
    ) -> Result<Vec<Vec<Complex<f32>>>, Box<dyn Error>> {
        if inputs.is_empty() {
            return Ok(Vec::new());
        }

        let n = Self::validate_batch_inputs(inputs)?;
        let batch_size = inputs.len();

        let flat_input: Vec<sys::float2> = inputs
            .iter()
            .flat_map(|v| v.iter().map(|c| sys::float2 { x: c.re, y: c.im }))
            .collect();

        let flat_output = self.batch_ifft_flattened(&flat_input, n, batch_size)?;
        let mut result: Vec<Vec<Complex<f32>>> =
            flat_output.chunks(n).map(Self::from_cuda_complex).collect();

        Ok(result)
    }

    /// Check if CUDA is available.
    pub fn is_available() -> bool {
        Self::new(16).is_ok()
    }
}