Skip to main content

baracuda_kernels/fft/
fftshift.rs

1//! `fftshift` / `ifftshift` — bespoke index-permutation kernels.
2//!
3//! cuFFT has no native shift, so these are hand-rolled kernels in
4//! `baracuda-kernels-sys`. The kernel is element-width-generic (4 / 8
5//! / 16-byte cells) so the same code covers `f32`, `f64`,
6//! [`Complex32`], and [`Complex64`] without per-type templating —
7//! shift is a pure index permutation (no arithmetic on the element
8//! values), so the element type is irrelevant beyond its byte width.
9//!
10//! 1-D shifts along the last axis of a `[batch, n]` tensor — matches
11//! NumPy / PyTorch convention:
12//!
13//! - `fftshift`:  `y[b, i] = x[b, (i + (n+1)/2) % n]` (equiv. roll(x, n//2))
14//! - `ifftshift`: `y[b, i] = x[b, (i + n/2)     % n]` (equiv. roll(x, -(n//2)))
15//!
16//! For even `n` the two are identical (the `n/2` offset is self-
17//! inverse mod `n`). For odd `n` the cyclic offsets differ by one
18//! cell and `ifftshift` is the genuine inverse of `fftshift`
19//! (`ifftshift(fftshift(x)) == x` for any `n`).
20
21use core::ffi::c_void;
22use core::marker::PhantomData;
23
24use baracuda_cutlass::{Error, Result};
25use baracuda_driver::Stream;
26use baracuda_kernels_sys::{
27    baracuda_kernels_fftshift_16_run, baracuda_kernels_fftshift_4_run,
28    baracuda_kernels_fftshift_8_run, baracuda_kernels_ifftshift_16_run,
29    baracuda_kernels_ifftshift_4_run, baracuda_kernels_ifftshift_8_run,
30};
31use baracuda_kernels_types::{
32    ArchSku, BackendKind, Element, ElementKind, FftKind, KernelSku, MathPrecision, OpCategory,
33    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
34};
35
36use super::fft::map_status;
37
38/// Descriptor for an fftshift / ifftshift op.
39#[derive(Copy, Clone, Debug)]
40pub struct FftShiftDescriptor {
41    /// Length of the last axis (the axis being shifted). For 1-D
42    /// fftshift `[batch, n]` this is the size of the shifted axis.
43    pub n: i32,
44    /// Number of independent rows. Each row is shifted independently.
45    pub batch: i32,
46    /// `true` selects `ifftshift` (cyclic offset `n/2`), `false`
47    /// selects `fftshift` (cyclic offset `(n+1)/2`). Identical for
48    /// even `n`; the two diverge for odd `n` and `ifftshift` is then
49    /// the true inverse of `fftshift`.
50    pub inverse: bool,
51    /// Element type. Any [`Element`]; the kernel dispatches on
52    /// `size_of::<T>()` (4 / 8 / 16 bytes).
53    pub element: ElementKind,
54}
55
56/// Args bundle for an fftshift / ifftshift.
57pub struct FftShiftArgs<'a, T: Element> {
58    /// Input tensor `[batch, n]`.
59    pub x: TensorRef<'a, T, 2>,
60    /// Output tensor `[batch, n]`. Must be distinct from `x` (the
61    /// kernel reads from `x` and writes to `y` without scratch — in-
62    /// place shift would require a 2-phase swap, which the trailblazer
63    /// doesn't ship).
64    pub y: TensorMut<'a, T, 2>,
65}
66
67/// 1-D `fftshift` / `ifftshift` plan — bespoke index-permutation
68/// kernel.
69///
70/// Cyclically shifts the last axis of a `[batch, n]` tensor by `n/2`
71/// (ifftshift) or `(n+1)/2` (fftshift). Matches NumPy / PyTorch
72/// conventions. For even `n` the two directions are identical; for odd
73/// `n` `ifftshift` is the genuine inverse of `fftshift`.
74///
75/// **When to use**: place the DC component at the centre of an FFT
76/// output (or vice versa). Use [`super::FftShiftNdPlan`] for shifts
77/// over multiple axes.
78///
79/// **Dtypes**: any [`Element`] — kernel dispatches on
80/// `size_of::<T>()` (4 / 8 / 16-byte cells), so `f32`, `f64`,
81/// `Complex32`, `Complex64` all work without per-type templating.
82///
83/// **Shape**: `[batch, n]`. Out-of-place only (in-place shift would
84/// need a 2-phase swap).
85///
86/// **Workspace**: zero.
87///
88/// **Precision guarantee**: bit-exact (pure index permutation, no
89/// arithmetic).
90///
91/// No cuFFT handle / state — the plan is just configuration.
92pub struct FftShiftPlan<T: Element> {
93    desc: FftShiftDescriptor,
94    sku: KernelSku,
95    _marker: PhantomData<T>,
96}
97
98impl<T: Element> FftShiftPlan<T> {
99    /// Pick a kernel + validate the descriptor.
100    pub fn select(
101        _stream: &Stream,
102        desc: &FftShiftDescriptor,
103        _pref: PlanPreference,
104    ) -> Result<Self> {
105        if desc.element != T::KIND {
106            return Err(Error::Unsupported(
107                "baracuda-kernels::FftShiftPlan: descriptor.element != T::KIND",
108            ));
109        }
110        // Kernel handles cells of 4, 8, or 16 bytes — any baracuda
111        // [`Element`] with one of those widths is supported. Today
112        // that's f32 / f64 / Complex32 / Complex64; rejecting the
113        // others up front keeps the supported set narrow and obvious.
114        let size = core::mem::size_of::<T>();
115        if !matches!(size, 4 | 8 | 16) {
116            return Err(Error::Unsupported(
117                "baracuda-kernels::FftShiftPlan: only 4/8/16-byte element types supported",
118            ));
119        }
120        if desc.n < 0 {
121            return Err(Error::InvalidProblem(
122                "baracuda-kernels::FftShiftPlan: n must be >= 0",
123            ));
124        }
125        if desc.batch < 0 {
126            return Err(Error::InvalidProblem(
127                "baracuda-kernels::FftShiftPlan: batch must be >= 0",
128            ));
129        }
130
131        let math_precision = match T::KIND {
132            ElementKind::F64 | ElementKind::Complex64 => MathPrecision::F64,
133            _ => MathPrecision::F32,
134        };
135        let precision_guarantee = PrecisionGuarantee {
136            math_precision,
137            accumulator: T::KIND,
138            // Pure index permutation — bit-exact, no arithmetic.
139            bit_stable_on_same_hardware: true,
140            deterministic: true,
141        };
142        let op = if desc.inverse {
143            FftKind::IfftShift
144        } else {
145            FftKind::FftShift
146        };
147        let sku = KernelSku {
148            category: OpCategory::Fft,
149            op: op as u16,
150            element: T::KIND,
151            aux_element: None,
152            layout: None,
153            epilogue: None,
154            arch: ArchSku::Sm80,
155            backend: BackendKind::Bespoke,
156            precision_guarantee,
157        };
158
159        Ok(Self {
160            desc: *desc,
161            sku,
162            _marker: PhantomData,
163        })
164    }
165
166    /// Kernel SKU identity.
167    #[inline]
168    pub fn sku(&self) -> KernelSku {
169        self.sku
170    }
171
172    /// Numerical guarantees.
173    #[inline]
174    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
175        self.sku.precision_guarantee
176    }
177
178    /// Workspace size in bytes.
179    #[inline]
180    pub fn workspace_size(&self) -> usize {
181        0
182    }
183
184    /// Run the fftshift / ifftshift.
185    pub fn run(
186        &self,
187        stream: &Stream,
188        _workspace: Workspace<'_>,
189        args: FftShiftArgs<'_, T>,
190    ) -> Result<()> {
191        let expected = [self.desc.batch, self.desc.n];
192        if args.x.shape != expected {
193            return Err(Error::InvalidProblem(
194                "baracuda-kernels::FftShiftPlan: x shape != [batch, n]",
195            ));
196        }
197        if args.y.shape != expected {
198            return Err(Error::InvalidProblem(
199                "baracuda-kernels::FftShiftPlan: y shape != [batch, n]",
200            ));
201        }
202        let numel = (self.desc.batch as i64) * (self.desc.n as i64);
203        if (args.x.data.len() as i64) < numel {
204            return Err(Error::BufferTooSmall {
205                needed: numel as usize,
206                got: args.x.data.len(),
207            });
208        }
209        if (args.y.data.len() as i64) < numel {
210            return Err(Error::BufferTooSmall {
211                needed: numel as usize,
212                got: args.y.data.len(),
213            });
214        }
215        if numel == 0 {
216            return Ok(());
217        }
218
219        let x_ptr = args.x.data.as_raw().0 as *const c_void;
220        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
221        let stream_ptr = stream.as_raw() as *mut c_void;
222        let batch = self.desc.batch as i64;
223        let n = self.desc.n;
224
225        let size = core::mem::size_of::<T>();
226        let status = unsafe {
227            match (size, self.desc.inverse) {
228                (4, false) => baracuda_kernels_fftshift_4_run(
229                    batch, n, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
230                ),
231                (4, true) => baracuda_kernels_ifftshift_4_run(
232                    batch, n, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
233                ),
234                (8, false) => baracuda_kernels_fftshift_8_run(
235                    batch, n, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
236                ),
237                (8, true) => baracuda_kernels_ifftshift_8_run(
238                    batch, n, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
239                ),
240                (16, false) => baracuda_kernels_fftshift_16_run(
241                    batch, n, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
242                ),
243                (16, true) => baracuda_kernels_ifftshift_16_run(
244                    batch, n, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
245                ),
246                _ => unreachable!("select() gates on size_of::<T>() in 4 / 8 / 16"),
247            }
248        };
249        map_status(status)
250    }
251}