#![cfg(any(feature = "vkfft", feature = "cufft"))]
use num_complex::Complex32;
use crate::{Fft, FftDirection, GpuError};
#[cfg(feature = "vkfft")]
use crate::{Device, GpuBuffer, VkFftBackend};
#[cfg(feature = "cufft")]
use crate::{CuFftBackend, CudaBuffer, CudaDevice};
/// FFT engine that dispatches to whichever GPU backend it was constructed with.
///
/// Only the variants whose Cargo feature (`vkfft`, `cufft`) is enabled are compiled in.
// NOTE(review): `large_enum_variant` is silenced rather than boxing the larger backend;
// the rationale is not visible here — confirm the size difference is acceptable.
#[allow(clippy::large_enum_variant)]
pub enum UniFftBackend {
    /// VkFFT-based backend (enabled by the `vkfft` feature).
    #[cfg(feature = "vkfft")]
    Vulkan(VkFftBackend),
    /// cuFFT-based backend (enabled by the `cufft` feature).
    #[cfg(feature = "cufft")]
    Cuda(CuFftBackend),
}

impl UniFftBackend {
    /// Creates an engine backed by [`VkFftBackend`] on `dev`.
    ///
    /// # Errors
    /// Forwards any error produced while constructing the VkFFT backend.
    #[cfg(feature = "vkfft")]
    pub fn vulkan(dev: &Device) -> Result<Self, GpuError> {
        let backend = VkFftBackend::new(dev)?;
        Ok(Self::Vulkan(backend))
    }

    /// Creates an engine backed by [`CuFftBackend`] on `dev`.
    ///
    /// # Errors
    /// Forwards any error produced while constructing the cuFFT backend.
    #[cfg(feature = "cufft")]
    pub fn cuda(dev: &CudaDevice) -> Result<Self, GpuError> {
        let backend = CuFftBackend::new(dev)?;
        Ok(Self::Cuda(backend))
    }
}
/// A device buffer of [`Complex32`] samples owned by one of the enabled GPU backends.
///
/// A buffer must be used with a [`UniFftBackend`] of the same variant; pairing it with a
/// different backend yields [`GpuError::BackendMismatch`].
// NOTE(review): `large_enum_variant` is silenced rather than boxing the larger buffer;
// the rationale is not visible here — confirm the size difference is acceptable.
#[allow(clippy::large_enum_variant)]
pub enum UniBuffer {
    /// Buffer allocated through the VkFFT/wgpu backend.
    #[cfg(feature = "vkfft")]
    Vulkan(GpuBuffer<Complex32>),
    /// Buffer allocated through the cuFFT backend.
    #[cfg(feature = "cufft")]
    Cuda(CudaBuffer),
}

impl UniBuffer {
    /// Uploads `host` into a new device buffer on the backend used by `engine`.
    ///
    /// # Errors
    /// Propagates any allocation or upload error from the underlying backend.
    pub fn from_slice(engine: &UniFftBackend, host: &[Complex32]) -> Result<Self, GpuError> {
        match engine {
            // STORAGE lets shaders bind the buffer; COPY_SRC/COPY_DST allow copying data
            // out of and into it (presumably needed for host readback/upload — confirm).
            #[cfg(feature = "vkfft")]
            UniFftBackend::Vulkan(b) => Ok(Self::Vulkan(GpuBuffer::<Complex32>::from_slice(
                b.device(),
                host,
                wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
            )?)),
            #[cfg(feature = "cufft")]
            UniFftBackend::Cuda(b) => Ok(Self::Cuda(CudaBuffer::from_slice(b.device(), host)?)),
        }
    }

    /// Copies the buffer contents back into host memory.
    ///
    /// `engine` must be the same backend the buffer was created with.
    ///
    /// # Errors
    /// Returns [`GpuError::BackendMismatch`] if the buffer and `engine` belong to different
    /// backends; otherwise propagates the backend's readback error.
    pub fn to_vec(&self, engine: &UniFftBackend) -> Result<Vec<Complex32>, GpuError> {
        match (self, engine) {
            #[cfg(feature = "vkfft")]
            (Self::Vulkan(b), UniFftBackend::Vulkan(e)) => b.to_vec(e.device()),
            // The CUDA readback does not need the engine handle, but a mismatched engine is
            // still rejected so this agrees with the Vulkan arm and the `Fft` dispatch.
            #[cfg(feature = "cufft")]
            (Self::Cuda(b), UniFftBackend::Cuda(_)) => b.to_vec(),
            // Only reachable when both backends are compiled in.
            #[allow(unreachable_patterns)]
            _ => Err(GpuError::BackendMismatch),
        }
    }

    /// Length reported by the underlying backend buffer.
    // NOTE(review): presumably a count of `Complex32` elements rather than bytes — confirm.
    pub fn len(&self) -> usize {
        match self {
            #[cfg(feature = "vkfft")]
            Self::Vulkan(b) => b.len(),
            #[cfg(feature = "cufft")]
            Self::Cuda(b) => b.len(),
        }
    }

    /// Returns `true` when [`UniBuffer::len`] is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Routes each transform to the backend that owns both the engine and the buffer.
///
/// Every method returns [`GpuError::BackendMismatch`] when `self` and `buf` come from
/// different backends; otherwise the backend's own result is returned unchanged.
impl Fft for UniFftBackend {
    type Buffer = UniBuffer;

    fn fft_1d(
        &mut self,
        buf: &mut Self::Buffer,
        n: u32,
        batch: u32,
        direction: FftDirection,
    ) -> Result<(), GpuError> {
        match self {
            #[cfg(feature = "vkfft")]
            Self::Vulkan(engine) => match buf {
                UniBuffer::Vulkan(data) => engine.fft_1d(data, n, batch, direction),
                // Reachable only when the other backend is compiled in as well.
                #[allow(unreachable_patterns)]
                _ => Err(GpuError::BackendMismatch),
            },
            #[cfg(feature = "cufft")]
            Self::Cuda(engine) => match buf {
                UniBuffer::Cuda(data) => engine.fft_1d(data, n, batch, direction),
                #[allow(unreachable_patterns)]
                _ => Err(GpuError::BackendMismatch),
            },
        }
    }

    fn fft_2d(
        &mut self,
        buf: &mut Self::Buffer,
        nx: u32,
        ny: u32,
        batch: u32,
        direction: FftDirection,
    ) -> Result<(), GpuError> {
        match self {
            #[cfg(feature = "vkfft")]
            Self::Vulkan(engine) => match buf {
                UniBuffer::Vulkan(data) => engine.fft_2d(data, nx, ny, batch, direction),
                // Reachable only when the other backend is compiled in as well.
                #[allow(unreachable_patterns)]
                _ => Err(GpuError::BackendMismatch),
            },
            #[cfg(feature = "cufft")]
            Self::Cuda(engine) => match buf {
                UniBuffer::Cuda(data) => engine.fft_2d(data, nx, ny, batch, direction),
                #[allow(unreachable_patterns)]
                _ => Err(GpuError::BackendMismatch),
            },
        }
    }

    fn fft_3d(
        &mut self,
        buf: &mut Self::Buffer,
        nx: u32,
        ny: u32,
        nz: u32,
        direction: FftDirection,
    ) -> Result<(), GpuError> {
        match self {
            #[cfg(feature = "vkfft")]
            Self::Vulkan(engine) => match buf {
                UniBuffer::Vulkan(data) => engine.fft_3d(data, nx, ny, nz, direction),
                // Reachable only when the other backend is compiled in as well.
                #[allow(unreachable_patterns)]
                _ => Err(GpuError::BackendMismatch),
            },
            #[cfg(feature = "cufft")]
            Self::Cuda(engine) => match buf {
                UniBuffer::Cuda(data) => engine.fft_3d(data, nx, ny, nz, direction),
                #[allow(unreachable_patterns)]
                _ => Err(GpuError::BackendMismatch),
            },
        }
    }
}