baracuda_kernels/fft/fftshift.rs
1//! `fftshift` / `ifftshift` — bespoke index-permutation kernels.
2//!
3//! cuFFT has no native shift, so these are hand-rolled kernels in
4//! `baracuda-kernels-sys`. The kernel is element-width-generic (4 / 8
5//! / 16-byte cells) so the same code covers `f32`, `f64`,
6//! [`Complex32`], and [`Complex64`] without per-type templating —
7//! shift is a pure index permutation (no arithmetic on the element
8//! values), so the element type is irrelevant beyond its byte width.
9//!
10//! 1-D shifts along the last axis of a `[batch, n]` tensor — matches
11//! NumPy / PyTorch convention:
12//!
13//! - `fftshift`: `y[b, i] = x[b, (i + (n+1)/2) % n]` (equiv. roll(x, n//2))
14//! - `ifftshift`: `y[b, i] = x[b, (i + n/2) % n]` (equiv. roll(x, -(n//2)))
15//!
16//! For even `n` the two are identical (the `n/2` offset is self-
17//! inverse mod `n`). For odd `n` the cyclic offsets differ by one
18//! cell and `ifftshift` is the genuine inverse of `fftshift`
19//! (`ifftshift(fftshift(x)) == x` for any `n`).
20
21use core::ffi::c_void;
22use core::marker::PhantomData;
23
24use baracuda_cutlass::{Error, Result};
25use baracuda_driver::Stream;
26use baracuda_kernels_sys::{
27 baracuda_kernels_fftshift_16_run, baracuda_kernels_fftshift_4_run,
28 baracuda_kernels_fftshift_8_run, baracuda_kernels_ifftshift_16_run,
29 baracuda_kernels_ifftshift_4_run, baracuda_kernels_ifftshift_8_run,
30};
31use baracuda_kernels_types::{
32 ArchSku, BackendKind, Element, ElementKind, FftKind, KernelSku, MathPrecision, OpCategory,
33 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
34};
35
36use super::fft::map_status;
37
38/// Descriptor for an fftshift / ifftshift op.
39#[derive(Copy, Clone, Debug)]
40pub struct FftShiftDescriptor {
41 /// Length of the last axis (the axis being shifted). For 1-D
42 /// fftshift `[batch, n]` this is the size of the shifted axis.
43 pub n: i32,
44 /// Number of independent rows. Each row is shifted independently.
45 pub batch: i32,
46 /// `true` selects `ifftshift` (cyclic offset `n/2`), `false`
47 /// selects `fftshift` (cyclic offset `(n+1)/2`). Identical for
48 /// even `n`; the two diverge for odd `n` and `ifftshift` is then
49 /// the true inverse of `fftshift`.
50 pub inverse: bool,
51 /// Element type. Any [`Element`]; the kernel dispatches on
52 /// `size_of::<T>()` (4 / 8 / 16 bytes).
53 pub element: ElementKind,
54}
55
56/// Args bundle for an fftshift / ifftshift.
57pub struct FftShiftArgs<'a, T: Element> {
58 /// Input tensor `[batch, n]`.
59 pub x: TensorRef<'a, T, 2>,
60 /// Output tensor `[batch, n]`. Must be distinct from `x` (the
61 /// kernel reads from `x` and writes to `y` without scratch — in-
62 /// place shift would require a 2-phase swap, which the trailblazer
63 /// doesn't ship).
64 pub y: TensorMut<'a, T, 2>,
65}
66
67/// 1-D `fftshift` / `ifftshift` plan — bespoke index-permutation
68/// kernel.
69///
70/// Cyclically shifts the last axis of a `[batch, n]` tensor by `n/2`
71/// (ifftshift) or `(n+1)/2` (fftshift). Matches NumPy / PyTorch
72/// conventions. For even `n` the two directions are identical; for odd
73/// `n` `ifftshift` is the genuine inverse of `fftshift`.
74///
75/// **When to use**: place the DC component at the centre of an FFT
76/// output (or vice versa). Use [`super::FftShiftNdPlan`] for shifts
77/// over multiple axes.
78///
79/// **Dtypes**: any [`Element`] — kernel dispatches on
80/// `size_of::<T>()` (4 / 8 / 16-byte cells), so `f32`, `f64`,
81/// `Complex32`, `Complex64` all work without per-type templating.
82///
83/// **Shape**: `[batch, n]`. Out-of-place only (in-place shift would
84/// need a 2-phase swap).
85///
86/// **Workspace**: zero.
87///
88/// **Precision guarantee**: bit-exact (pure index permutation, no
89/// arithmetic).
90///
91/// No cuFFT handle / state — the plan is just configuration.
92pub struct FftShiftPlan<T: Element> {
93 desc: FftShiftDescriptor,
94 sku: KernelSku,
95 _marker: PhantomData<T>,
96}
97
98impl<T: Element> FftShiftPlan<T> {
99 /// Pick a kernel + validate the descriptor.
100 pub fn select(
101 _stream: &Stream,
102 desc: &FftShiftDescriptor,
103 _pref: PlanPreference,
104 ) -> Result<Self> {
105 if desc.element != T::KIND {
106 return Err(Error::Unsupported(
107 "baracuda-kernels::FftShiftPlan: descriptor.element != T::KIND",
108 ));
109 }
110 // Kernel handles cells of 4, 8, or 16 bytes — any baracuda
111 // [`Element`] with one of those widths is supported. Today
112 // that's f32 / f64 / Complex32 / Complex64; rejecting the
113 // others up front keeps the supported set narrow and obvious.
114 let size = core::mem::size_of::<T>();
115 if !matches!(size, 4 | 8 | 16) {
116 return Err(Error::Unsupported(
117 "baracuda-kernels::FftShiftPlan: only 4/8/16-byte element types supported",
118 ));
119 }
120 if desc.n < 0 {
121 return Err(Error::InvalidProblem(
122 "baracuda-kernels::FftShiftPlan: n must be >= 0",
123 ));
124 }
125 if desc.batch < 0 {
126 return Err(Error::InvalidProblem(
127 "baracuda-kernels::FftShiftPlan: batch must be >= 0",
128 ));
129 }
130
131 let math_precision = match T::KIND {
132 ElementKind::F64 | ElementKind::Complex64 => MathPrecision::F64,
133 _ => MathPrecision::F32,
134 };
135 let precision_guarantee = PrecisionGuarantee {
136 math_precision,
137 accumulator: T::KIND,
138 // Pure index permutation — bit-exact, no arithmetic.
139 bit_stable_on_same_hardware: true,
140 deterministic: true,
141 };
142 let op = if desc.inverse {
143 FftKind::IfftShift
144 } else {
145 FftKind::FftShift
146 };
147 let sku = KernelSku {
148 category: OpCategory::Fft,
149 op: op as u16,
150 element: T::KIND,
151 aux_element: None,
152 layout: None,
153 epilogue: None,
154 arch: ArchSku::Sm80,
155 backend: BackendKind::Bespoke,
156 precision_guarantee,
157 };
158
159 Ok(Self {
160 desc: *desc,
161 sku,
162 _marker: PhantomData,
163 })
164 }
165
166 /// Kernel SKU identity.
167 #[inline]
168 pub fn sku(&self) -> KernelSku {
169 self.sku
170 }
171
172 /// Numerical guarantees.
173 #[inline]
174 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
175 self.sku.precision_guarantee
176 }
177
178 /// Workspace size in bytes.
179 #[inline]
180 pub fn workspace_size(&self) -> usize {
181 0
182 }
183
184 /// Run the fftshift / ifftshift.
185 pub fn run(
186 &self,
187 stream: &Stream,
188 _workspace: Workspace<'_>,
189 args: FftShiftArgs<'_, T>,
190 ) -> Result<()> {
191 let expected = [self.desc.batch, self.desc.n];
192 if args.x.shape != expected {
193 return Err(Error::InvalidProblem(
194 "baracuda-kernels::FftShiftPlan: x shape != [batch, n]",
195 ));
196 }
197 if args.y.shape != expected {
198 return Err(Error::InvalidProblem(
199 "baracuda-kernels::FftShiftPlan: y shape != [batch, n]",
200 ));
201 }
202 let numel = (self.desc.batch as i64) * (self.desc.n as i64);
203 if (args.x.data.len() as i64) < numel {
204 return Err(Error::BufferTooSmall {
205 needed: numel as usize,
206 got: args.x.data.len(),
207 });
208 }
209 if (args.y.data.len() as i64) < numel {
210 return Err(Error::BufferTooSmall {
211 needed: numel as usize,
212 got: args.y.data.len(),
213 });
214 }
215 if numel == 0 {
216 return Ok(());
217 }
218
219 let x_ptr = args.x.data.as_raw().0 as *const c_void;
220 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
221 let stream_ptr = stream.as_raw() as *mut c_void;
222 let batch = self.desc.batch as i64;
223 let n = self.desc.n;
224
225 let size = core::mem::size_of::<T>();
226 let status = unsafe {
227 match (size, self.desc.inverse) {
228 (4, false) => baracuda_kernels_fftshift_4_run(
229 batch, n, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
230 ),
231 (4, true) => baracuda_kernels_ifftshift_4_run(
232 batch, n, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
233 ),
234 (8, false) => baracuda_kernels_fftshift_8_run(
235 batch, n, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
236 ),
237 (8, true) => baracuda_kernels_ifftshift_8_run(
238 batch, n, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
239 ),
240 (16, false) => baracuda_kernels_fftshift_16_run(
241 batch, n, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
242 ),
243 (16, true) => baracuda_kernels_ifftshift_16_run(
244 batch, n, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
245 ),
246 _ => unreachable!("select() gates on size_of::<T>() in 4 / 8 / 16"),
247 }
248 };
249 map_status(status)
250 }
251}