#[cfg(feature = "wgpu_fft")]
mod inner {
    use super::super::pipeline::GpuFftPipeline;
    use super::super::types::{FftDirection, GpuFftConfig, NormalizationMode};
    use crate::error::FFTError;
    use scirs2_core::numeric::Complex64;
    use wgpu::{Backends, Instance, InstanceDescriptor, PowerPreference, RequestAdapterOptions};

    /// Errors reported by the wgpu FFT backend.
    ///
    /// Every variant converts into [`FFTError::BackendError`] carrying the
    /// variant's `Display` text. Public entry points such as [`fft_wgpu`]
    /// therefore surface these as `FFTError`.
    #[derive(Debug, thiserror::Error)]
    pub enum FftBackendError {
        /// No wgpu adapter could be obtained on this host.
        #[error("no wgpu adapter available (GPU unavailable or unsupported)")]
        NoAdapter,
        /// wgpu device creation failed; the payload is interpolated into the message.
        #[error("wgpu device creation failed: {0}")]
        DeviceCreation(String),
        /// WGSL shader compilation failed; the payload is interpolated into the message.
        #[error("WGSL shader compilation failed: {0}")]
        ShaderCompilation(String),
        /// A GPU buffer operation failed; the payload is interpolated into the message.
        #[error("GPU buffer operation failed: {0}")]
        Buffer(String),
        /// The input length (the payload) is not a power of two.
        #[error("wgpu FFT requires a power-of-two input length; got {0}")]
        NonPowerOfTwo(usize),
    }

    impl From<FftBackendError> for FFTError {
        /// Wraps the backend error's `Display` text in [`FFTError::BackendError`].
        fn from(e: FftBackendError) -> Self {
            FFTError::BackendError(e.to_string())
        }
    }

    /// Returns `true` if wgpu can obtain an adapter on any backend.
    ///
    /// The probe builds a fresh [`Instance`] over [`Backends::all()`]. It then
    /// requests an adapter with the default power preference, no compatible
    /// surface, and no forced fallback adapter.
    ///
    /// The call blocks the current thread (via `pollster::block_on`) until the
    /// request resolves, so prefer calling it outside async contexts.
    ///
    /// NOTE(review): every call creates a new wgpu `Instance`. Consider caching
    /// the result if this is probed on a hot path.
    pub fn gpu_available() -> bool {
        let instance_desc = InstanceDescriptor {
            backends: Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            memory_budget_thresholds: Default::default(),
            backend_options: Default::default(),
            display: None,
        };
        let instance = Instance::new(instance_desc);
        pollster::block_on(async {
            instance
                .request_adapter(&RequestAdapterOptions {
                    power_preference: PowerPreference::default(),
                    compatible_surface: None,
                    force_fallback_adapter: false,
                })
                .await
                .is_ok()
        })
    }

    /// Computes a power-of-two FFT of `input` on the GPU through wgpu.
    ///
    /// `inverse` selects the transform direction:
    /// - The forward transform is configured with [`NormalizationMode::None`].
    /// - The inverse is configured with [`NormalizationMode::Backward`]
    ///   (presumably the "scale by 1/n on the inverse" convention; confirm
    ///   against `NormalizationMode`).
    ///
    /// `input` is not modified; the transformed data is returned in a new vector.
    ///
    /// # Errors
    ///
    /// Every failure is reported as [`FFTError::BackendError`]:
    /// - The length is not a power of two. This includes an empty slice, because
    ///   `0` is not a power of two. This check runs before any GPU work.
    /// - No wgpu adapter is available (see [`gpu_available`]).
    /// - The GPU pipeline fails while executing.
    pub fn fft_wgpu(input: &[Complex64], inverse: bool) -> Result<Vec<Complex64>, FFTError> {
        let n = input.len();
        // Validate the length first so bad input fails fast without probing the GPU.
        if !n.is_power_of_two() {
            return Err(FftBackendError::NonPowerOfTwo(n).into());
        }
        if !gpu_available() {
            return Err(FftBackendError::NoAdapter.into());
        }

        let (direction, normalization) = if inverse {
            (FftDirection::Inverse, NormalizationMode::Backward)
        } else {
            (FftDirection::Forward, NormalizationMode::None)
        };
        let pipeline = GpuFftPipeline::new(GpuFftConfig {
            normalization,
            ..GpuFftConfig::default()
        });

        // The pipeline works on a mutable buffer, so transform an owned copy of the input.
        let mut buf = input.to_vec();
        pipeline
            .execute(&mut buf, n, direction)
            .map_err(|e| FFTError::BackendError(e.to_string()))?;
        Ok(buf)
    }
}
#[cfg(feature = "wgpu_fft")]
pub use inner::{fft_wgpu, gpu_available, FftBackendError};
#[cfg(all(test, feature = "wgpu_fft"))]
mod tests {
    use super::{fft_wgpu, gpu_available};
    use crate::error::FFTError;
    use scirs2_core::numeric::Complex64;

    /// Smoke test: the adapter probe must finish without panicking on any host.
    /// The value depends on the hardware, so it is only logged.
    #[test]
    fn test_gpu_available_returns_bool() {
        let result: bool = gpu_available();
        println!("gpu_available() = {result}");
    }

    /// Non-power-of-two lengths, including empty input, are rejected before any
    /// GPU work. The outcome is therefore the same with or without an adapter.
    #[test]
    fn test_fft_wgpu_rejects_non_power_of_two() {
        for len in [0usize, 3, 6] {
            let input = vec![Complex64::new(1.0, 0.0); len];
            for inverse in [false, true] {
                let result = fft_wgpu(&input, inverse);
                assert!(
                    matches!(
                        result,
                        Err(FFTError::BackendError(ref msg)) if msg.contains("power-of-two")
                    ),
                    "length {len} (inverse = {inverse}) should be rejected as non-power-of-two"
                );
            }
        }
    }
}