```rust
#[non_exhaustive]
#[repr(u16)]
pub enum FftKind {
    Fft = 0,
    Ifft = 1,
    Rfft = 2,
    Irfft = 3,
    FftShift = 4,
    IfftShift = 5,
}
```
FFT-family op discriminant — Category U from the comprehensive plan.
Stored as u16 in crate::KernelSku::op when
category == OpCategory::Fft. Milestone 6.4 wires the four
canonical PyTorch / JAX 1-D FFTs (fft / ifft / rfft / irfft)
plus the two index-permutation helpers (fftshift / ifftshift).
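Because the kind travels as a raw u16 in crate::KernelSku::op, reading it back needs a fallible conversion. The following is a minimal sketch that assumes a hand-written TryFrom<u16> impl (not confirmed to exist in the crate). The enum is mirrored locally so the snippet compiles on its own, and an unknown value is returned as the error.

```rust
use std::convert::TryFrom;

/// Local mirror of the documented enum (same variants and discriminants).
#[non_exhaustive]
#[repr(u16)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FftKind {
    Fft = 0,
    Ifft = 1,
    Rfft = 2,
    Irfft = 3,
    FftShift = 4,
    IfftShift = 5,
}

/// Encoding is a plain discriminant cast thanks to #[repr(u16)].
impl From<FftKind> for u16 {
    fn from(kind: FftKind) -> u16 {
        kind as u16
    }
}

/// Hypothetical decoder for a raw `op` value; unknown values come back as Err.
impl TryFrom<u16> for FftKind {
    type Error = u16;

    fn try_from(op: u16) -> Result<Self, Self::Error> {
        match op {
            0 => Ok(FftKind::Fft),
            1 => Ok(FftKind::Ifft),
            2 => Ok(FftKind::Rfft),
            3 => Ok(FftKind::Irfft),
            4 => Ok(FftKind::FftShift),
            5 => Ok(FftKind::IfftShift),
            other => Err(other),
        }
    }
}

fn main() {
    let raw: u16 = FftKind::Rfft.into();
    assert_eq!(raw, 2);
    assert_eq!(FftKind::try_from(3), Ok(FftKind::Irfft));
    assert_eq!(FftKind::try_from(42), Err(42));
}
```

Returning the raw value as the error keeps decoding cheap and lets the caller report exactly which discriminant it did not recognize.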
The trailblazer covers 1-D transforms only. Multi-D FFTs (fft2, fftn, …) and
arbitrary-axis FFTs follow in fanout sessions; they don't require
new cuFFT bindings, only additional descriptor shapes and plan glue.
Dtype coverage: f32 (single precision) and f64 (double
precision) only. cuFFT’s main API does not expose f16 / bf16
for native transforms, so callers working in reduced precision must
cast up to f32 (or f64) before the transform and back down afterward.
Spectrum-domain tensors use crate::Complex32 /
crate::Complex64 for the interleaved real/imag pairs.
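The interleaved layout means each complex element occupies two adjacent reals, real part first. Here is a minimal sketch, assuming crate::Complex32 is a #[repr(C)] struct with `re` then `im` fields; the struct below is a hypothetical stand-in, not the crate's definition.

```rust
/// Hypothetical stand-in for crate::Complex32: one interleaved (re, im) pair.
/// #[repr(C)] pins the field order so `re` precedes `im` in memory.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Complex32 {
    pub re: f32,
    pub im: f32,
}

/// Flattens complex values into the interleaved [re0, im0, re1, im1, ...] layout.
fn interleave(values: &[Complex32]) -> Vec<f32> {
    values.iter().flat_map(|c| [c.re, c.im]).collect()
}

fn main() {
    // Two f32 fields, no padding: 8 bytes per element.
    assert_eq!(std::mem::size_of::<Complex32>(), 8);
    let spectrum = [Complex32 { re: 1.0, im: 2.0 }, Complex32 { re: 3.0, im: 4.0 }];
    assert_eq!(interleave(&spectrum), vec![1.0, 2.0, 3.0, 4.0]);
}
```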
Normalization: forward transforms are unnormalized; inverse
transforms are normalized by 1/N to match PyTorch’s
norm="backward" default. cuFFT's inverse exec is itself unnormalized and
returns N · IFFT(x), so the plan layer multiplies the result by 1/N afterward.
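The normalization rule can be checked with a naive O(N²) DFT standing in for cuFFT. This is a sketch, not the crate's plan layer: `dft_unnormalized` plays the role of the unscaled cuFFT exec, and `ifft` applies the 1/N post-scale described above. Complex values are plain (re, im) tuples for brevity.

```rust
use std::f64::consts::PI;

/// A complex value as an (re, im) pair of f64.
type C64 = (f64, f64);

/// Naive O(n^2) DFT with no scaling in either direction.
/// `sign = -1.0` gives the forward transform, `sign = 1.0` the inverse kernel.
fn dft_unnormalized(x: &[C64], sign: f64) -> Vec<C64> {
    let n = x.len();
    (0..n)
        .map(|k| {
            let mut acc = (0.0, 0.0);
            for (j, &(re, im)) in x.iter().enumerate() {
                // Reduce j*k mod n to keep the angle small.
                let theta = sign * 2.0 * PI * ((j * k) % n) as f64 / n as f64;
                let (s, c) = theta.sin_cos();
                // (re + i*im) * (c + i*s)
                acc.0 += re * c - im * s;
                acc.1 += re * s + im * c;
            }
            acc
        })
        .collect()
}

/// Forward transform: unnormalized, like cuFFT.
fn fft(x: &[C64]) -> Vec<C64> {
    dft_unnormalized(x, -1.0)
}

/// Inverse transform: unnormalized kernel, then the 1/N post-scale (norm="backward").
fn ifft(x: &[C64]) -> Vec<C64> {
    let scale = 1.0 / x.len() as f64;
    dft_unnormalized(x, 1.0)
        .into_iter()
        .map(|(re, im)| (re * scale, im * scale))
        .collect()
}

fn main() {
    let x: Vec<C64> = vec![(1.0, 0.0), (2.0, -1.0), (0.5, 3.0), (-4.0, 0.25)];
    let n = x.len() as f64;
    let spectrum = fft(&x);
    // Without the post-scale, the inverse exec returns N * x ...
    let raw = dft_unnormalized(&spectrum, 1.0);
    // ... and multiplying by 1/N recovers x.
    let back = ifft(&spectrum);
    for i in 0..x.len() {
        assert!((raw[i].0 - n * x[i].0).abs() < 1e-9 && (raw[i].1 - n * x[i].1).abs() < 1e-9);
        assert!((back[i].0 - x[i].0).abs() < 1e-9 && (back[i].1 - x[i].1).abs() < 1e-9);
    }
}
```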
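Variants (Non-exhaustive)
This enum is marked as non-exhaustive: variants may be added in the future, so code outside this crate must include a wildcard arm when matching on it.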
Fft = 0
y = FFT(x) — complex-to-complex forward transform (unnormalized).
PyTorch torch.fft.fft. Both input and output are complex with
the same shape [batch, n].
Ifft = 1
y = IFFT(x) — complex-to-complex inverse transform, normalized
by 1/N to match PyTorch’s norm="backward". PyTorch
torch.fft.ifft. Both input and output are complex [batch, n].
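For reference, the standard DFT pair under PyTorch's norm="backward" convention, which Fft and Ifft follow, is shown below for a length-N row; cuFFT's unnormalized inverse computes N times the right-hand expression.

```latex
\mathrm{FFT}(x)[k] = \sum_{j=0}^{N-1} x[j]\, e^{-2\pi i\, jk/N},
\qquad
\mathrm{IFFT}(X)[j] = \frac{1}{N} \sum_{k=0}^{N-1} X[k]\, e^{+2\pi i\, jk/N}
```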
Rfft = 2
y = RFFT(x) — real-to-complex forward transform (unnormalized).
PyTorch torch.fft.rfft. Input is real [batch, n]; output is
complex [batch, n // 2 + 1] (the Hermitian half).
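For real input the spectrum satisfies X[n - k] = conj(X[k]), so keeping bins 0..=n // 2 loses nothing. Below is a naive sketch of the real-to-complex transform that computes only those bins; it is a reference model under that assumption, not how cuFFT evaluates rfft.

```rust
use std::f64::consts::PI;

/// Length of the Hermitian-half spectrum for a real signal of length n.
fn rfft_len(n: usize) -> usize {
    n / 2 + 1
}

/// Naive real-to-complex DFT (unnormalized) returning bins 0..=n/2 as (re, im).
fn rfft(x: &[f64]) -> Vec<(f64, f64)> {
    let n = x.len();
    (0..rfft_len(n))
        .map(|k| {
            x.iter().enumerate().fold((0.0, 0.0), |(re, im), (j, &v)| {
                let theta = -2.0 * PI * (j * k) as f64 / n as f64;
                (re + v * theta.cos(), im + v * theta.sin())
            })
        })
        .collect()
}

fn main() {
    // An even and an odd length can share one Hermitian-half length.
    assert_eq!(rfft_len(8), 5);
    assert_eq!(rfft_len(9), 5);
    let spectrum = rfft(&[1.0, 2.0, 3.0, 4.0]);
    assert_eq!(spectrum.len(), 3);
    // Bin 0 is the plain sum of the input.
    assert!((spectrum[0].0 - 10.0).abs() < 1e-12 && spectrum[0].1.abs() < 1e-12);
}
```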
Irfft = 3
y = IRFFT(x, n) — complex-to-real inverse transform, normalized
by 1/n. PyTorch torch.fft.irfft. Input is complex
[batch, n // 2 + 1], output is real [batch, n]. The output
length n is a required descriptor parameter: it cannot be inferred
from the Hermitian-half input shape, because an input of length m
is produced both by n = 2*(m - 1) and by n = 2*(m - 1) + 1
(for example, n = 8 and n = 9 both give m = 5).
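Since the input shape alone cannot pin down n, a launch path would validate the explicit n against the input instead. This is a sketch with a hypothetical `IrfftDesc` and `validate_irfft` (neither is confirmed crate API), showing that one input length accepts two different output lengths.

```rust
/// Hypothetical descriptor fields for an irfft launch.
#[derive(Debug, Clone, Copy)]
struct IrfftDesc {
    /// Real output length; must be supplied explicitly.
    n: usize,
}

/// Hermitian-half length produced by a real signal of length n.
fn half_len(n: usize) -> usize {
    n / 2 + 1
}

/// The two output lengths consistent with a Hermitian-half input of length m (m >= 1).
fn candidate_lengths(m: usize) -> (usize, usize) {
    assert!(m >= 1, "Hermitian-half length must be at least 1");
    (2 * (m - 1), 2 * (m - 1) + 1)
}

/// Rejects a complex input whose last-axis length does not match desc.n.
fn validate_irfft(desc: IrfftDesc, input_len: usize) -> Result<(), String> {
    if desc.n == 0 {
        return Err("irfft output length n must be positive".to_string());
    }
    let expected = half_len(desc.n);
    if input_len != expected {
        return Err(format!(
            "irfft input length {} does not match n = {} (expected {})",
            input_len, desc.n, expected
        ));
    }
    Ok(())
}

fn main() {
    // m = 5 is ambiguous on its own: it could come from n = 8 or n = 9.
    assert_eq!(candidate_lengths(5), (8, 9));
    assert!(validate_irfft(IrfftDesc { n: 8 }, 5).is_ok());
    assert!(validate_irfft(IrfftDesc { n: 9 }, 5).is_ok());
    assert!(validate_irfft(IrfftDesc { n: 10 }, 5).is_err());
}
```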
FftShift = 4
fftshift — shift the zero-frequency component to the center of
the spectrum. PyTorch torch.fft.fftshift (matches NumPy’s
np.fft.fftshift).
Equivalent to roll(x, n // 2), giving:
y[i] = x[(i - n // 2) mod n] = x[(i + (n+1) // 2) mod n].
Bit-exact (pure index permutation, no arithmetic on values).
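The two forms of fftshift given above (the index formula and roll(x, n // 2)) can be compared directly. This is a one-axis sketch on a slice; `rotate_right(k)` moves the last k elements to the front, which is exactly roll by k.

```rust
/// fftshift along one axis via the index formula: y[i] = x[(i + (n + 1) / 2) % n].
fn fftshift<T: Copy>(x: &[T]) -> Vec<T> {
    let n = x.len();
    (0..n).map(|i| x[(i + (n + 1) / 2) % n]).collect()
}

/// The same permutation written as roll(x, n / 2).
fn fftshift_roll<T: Copy>(x: &[T]) -> Vec<T> {
    let mut y = x.to_vec();
    let n = y.len();
    y.rotate_right(n / 2);
    y
}

fn main() {
    // Frequency order for n = 5 is [0, 1, 2, -2, -1]; fftshift centers bin 0.
    let freqs = [0, 1, 2, -2, -1];
    assert_eq!(fftshift(&freqs), vec![-2, -1, 0, 1, 2]);
    // Both forms agree for even and odd lengths.
    for n in 0..10 {
        let x: Vec<usize> = (0..n).collect();
        assert_eq!(fftshift(&x), fftshift_roll(&x));
    }
}
```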
IfftShift = 5
ifftshift — true inverse of fftshift:
ifftshift(fftshift(x)) == x for any n. PyTorch
torch.fft.ifftshift.
Equivalent to roll(x, -(n // 2)), giving:
y[i] = x[(i + n // 2) mod n].
For even n this is identical to fftshift (the n/2 offset
is self-inverse mod n); for odd n the two cyclic offsets
differ by one cell. Bit-exact.
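The round-trip property and the even/odd behavior can be checked with a small one-axis sketch; fftshift is repeated here so the snippet stands alone.

```rust
/// fftshift along one axis: roll(x, n / 2).
fn fftshift<T: Copy>(x: &[T]) -> Vec<T> {
    let n = x.len();
    (0..n).map(|i| x[(i + (n + 1) / 2) % n]).collect()
}

/// ifftshift along one axis: roll(x, -(n / 2)), i.e. y[i] = x[(i + n / 2) % n].
fn ifftshift<T: Copy>(x: &[T]) -> Vec<T> {
    let n = x.len();
    (0..n).map(|i| x[(i + n / 2) % n]).collect()
}

fn main() {
    // Exact inverses in both orders, for every length.
    for n in 0..10 {
        let x: Vec<usize> = (0..n).collect();
        assert_eq!(ifftshift(&fftshift(&x)), x);
        assert_eq!(fftshift(&ifftshift(&x)), x);
    }
    // Even n: the two shifts coincide.
    assert_eq!(fftshift(&[0, 1, 2, 3]), ifftshift(&[0, 1, 2, 3]));
    // Odd n: the cyclic offsets differ by one position.
    assert_eq!(fftshift(&[0, 1, 2, 3, 4]), vec![3, 4, 0, 1, 2]);
    assert_eq!(ifftshift(&[0, 1, 2, 3, 4]), vec![2, 3, 4, 0, 1]);
}
```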