Skip to main content

baracuda_kernels/fft/
rfft.rs

1//! RFFT (real-to-complex) and IRFFT (complex-to-real) — `RfftPlan<T>`
2//! and `IrfftPlan<T>` for `T = f32` or `T = f64`.
3//!
4//! Wraps cuFFT's `cufftExecR2C` / `cufftExecD2Z` (forward) and
5//! `cufftExecC2R` / `cufftExecZ2D` (inverse).
6//!
7//! Input / output shape contract:
8//! - RFFT: input `[batch, n]` real, output `[batch, n/2 + 1]` complex
9//!   (Hermitian-half — the missing half is the conjugate of the
10//!   present half).
11//! - IRFFT: input `[batch, n/2 + 1]` complex, output `[batch, n]`
12//!   real. cuFFT cannot infer the output length `n` from the Hermitian-
13//!   half input alone (both `2*(n/2)` and `2*(n/2)+1` map to the same
14//!   half), so `n` is a required descriptor parameter.
15//!
16//! For inverse transforms the plan applies `1/n` normalization
17//! in-place after the cuFFT exec, matching PyTorch's `norm="backward"`.
18
19use core::cell::Cell;
20use core::ffi::c_void;
21use core::marker::PhantomData;
22
23use baracuda_cutlass::{Error, Result};
24use baracuda_driver::Stream;
25use baracuda_kernels_sys::{
26    baracuda_kernels_scale_inplace_real_f32_run, baracuda_kernels_scale_inplace_real_f64_run,
27    cufftComplex, cufftDestroy, cufftDoubleComplex, cufftExecC2R, cufftExecD2Z, cufftExecR2C,
28    cufftExecZ2D, cufftHandle, cufftPlan1d, cufftSetStream, CUFFT_C2R, CUFFT_D2Z, CUFFT_R2C,
29    CUFFT_Z2D,
30};
31use baracuda_kernels_types::{
32    ArchSku, BackendKind, Complex32, Complex64, Element, ElementKind, FftKind, KernelSku,
33    MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
34};
35
36use super::fft::{cufft_to_status, map_status};
37
/// Sentinel meaning "no cuFFT plan has been created yet". Each plan's
/// `handle` cell starts at this value, `ensure_handle` replaces it with a
/// real handle on first use, and `Drop` writes it back after destroying
/// the handle.
const HANDLE_UNINIT: cufftHandle = -1;
39
40// =============================================================================
41// RFFT — real → complex (Hermitian-half)
42// =============================================================================
43
44/// Descriptor for an RFFT (real-to-complex) op.
45#[derive(Copy, Clone, Debug)]
46pub struct RfftDescriptor {
47    /// Signal length (the real input length). Output has shape
48    /// `[batch, n/2 + 1]` (Hermitian-half).
49    pub n: i32,
50    /// Number of independent transforms in one launch.
51    pub batch: i32,
52    /// Real-side element type — `F32` or `F64`. The complex output
53    /// type is `Complex32` (for `F32`) or `Complex64` (for `F64`).
54    pub element: ElementKind,
55}
56
/// Args bundle for an RFFT.
///
/// `T` is the *real* input element type and `C` the complex output element
/// type. The `run` impls in this file pair `f32` with [`Complex32`] and `f64`
/// with [`Complex64`].
pub struct RfftArgs<'a, T: Element, C: Element> {
    /// Real input tensor `[batch, n]`.
    pub x: TensorRef<'a, T, 2>,
    /// Complex output tensor `[batch, n/2 + 1]` (non-redundant Hermitian
    /// half); `run` leaves it unnormalized.
    pub y: TensorMut<'a, C, 2>,
}
67
68/// 1-D RFFT plan — real input → Hermitian-half complex output.
69///
70/// Wraps cuFFT's `cufftExecR2C` (`f32`) / `cufftExecD2Z` (`f64`).
71///
72/// **When to use**: forward FFT of real-valued data; the output is the
73/// non-redundant Hermitian half. Pair with [`IrfftPlan`] for the
74/// inverse direction. Use [`super::FftPlan`] when the input is
75/// already complex.
76///
77/// **Dtypes**: `f32` → `Complex32`; `f64` → `Complex64`.
78///
79/// **Shape**: real `[batch, n]` → complex `[batch, n/2 + 1]`.
80///
81/// **Normalization**: unnormalized (`norm="backward"`).
82///
83/// **Workspace**: zero — cuFFT manages internal workspace.
84///
85/// **Precision guarantee**: deterministic; not bit-stable across
86/// cuFFT versions.
87///
88/// Owns a lazy cuFFT handle (`!Sync` / `!Send`); destroyed on `Drop`.
89pub struct RfftPlan<T: Element> {
90    desc: RfftDescriptor,
91    sku: KernelSku,
92    handle: Cell<cufftHandle>,
93    _marker: PhantomData<T>,
94}
95
96impl<T: Element> RfftPlan<T> {
97    /// Pick a kernel + validate the descriptor.
98    pub fn select(_stream: &Stream, desc: &RfftDescriptor, _pref: PlanPreference) -> Result<Self> {
99        if desc.element != T::KIND {
100            return Err(Error::Unsupported(
101                "baracuda-kernels::RfftPlan: descriptor.element != T::KIND",
102            ));
103        }
104        if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
105            return Err(Error::Unsupported(
106                "baracuda-kernels::RfftPlan: R2C FFT supports f32 + f64 only",
107            ));
108        }
109        if desc.n <= 0 {
110            return Err(Error::InvalidProblem(
111                "baracuda-kernels::RfftPlan: n must be > 0",
112            ));
113        }
114        if desc.batch <= 0 {
115            return Err(Error::InvalidProblem(
116                "baracuda-kernels::RfftPlan: batch must be > 0",
117            ));
118        }
119        let math_precision = match T::KIND {
120            ElementKind::F64 => MathPrecision::F64,
121            _ => MathPrecision::F32,
122        };
123        let aux = match T::KIND {
124            ElementKind::F32 => Some(ElementKind::Complex32),
125            ElementKind::F64 => Some(ElementKind::Complex64),
126            _ => None,
127        };
128        let precision_guarantee = PrecisionGuarantee {
129            math_precision,
130            accumulator: T::KIND,
131            bit_stable_on_same_hardware: false,
132            deterministic: true,
133        };
134        let sku = KernelSku {
135            category: OpCategory::Fft,
136            op: FftKind::Rfft as u16,
137            element: T::KIND,
138            aux_element: aux,
139            layout: None,
140            epilogue: None,
141            arch: ArchSku::Sm80,
142            backend: BackendKind::Cufft,
143            precision_guarantee,
144        };
145        Ok(Self {
146            desc: *desc,
147            sku,
148            handle: Cell::new(HANDLE_UNINIT),
149            _marker: PhantomData,
150        })
151    }
152
153    /// Kernel SKU identity.
154    #[inline]
155    pub fn sku(&self) -> KernelSku {
156        self.sku
157    }
158
159    /// Numerical guarantees.
160    #[inline]
161    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
162        self.sku.precision_guarantee
163    }
164
165    /// Workspace size in bytes — cuFFT-internal, no caller-supplied
166    /// workspace needed.
167    #[inline]
168    pub fn workspace_size(&self) -> usize {
169        0
170    }
171
172    fn ensure_handle(&self) -> Result<cufftHandle> {
173        let h = self.handle.get();
174        if h != HANDLE_UNINIT {
175            return Ok(h);
176        }
177        let fft_type = match T::KIND {
178            ElementKind::F32 => CUFFT_R2C,
179            ElementKind::F64 => CUFFT_D2Z,
180            _ => unreachable!("select() gates on F32 / F64"),
181        };
182        let mut handle: cufftHandle = HANDLE_UNINIT;
183        let status = unsafe {
184            cufftPlan1d(
185                &mut handle as *mut _,
186                self.desc.n,
187                fft_type,
188                self.desc.batch,
189            )
190        };
191        if status != 0 {
192            return Err(Error::CutlassInternal(cufft_to_status(status)));
193        }
194        self.handle.set(handle);
195        Ok(handle)
196    }
197
198    fn bind_stream(&self, handle: cufftHandle, stream: &Stream) -> Result<()> {
199        let stream_ptr = stream.as_raw() as *mut c_void;
200        let status = unsafe { cufftSetStream(handle, stream_ptr) };
201        if status != 0 {
202            return Err(Error::CutlassInternal(cufft_to_status(status)));
203        }
204        Ok(())
205    }
206}
207
208impl RfftPlan<f32> {
209    /// Run the R2C FFT (single precision).
210    pub fn run(
211        &self,
212        stream: &Stream,
213        _workspace: Workspace<'_>,
214        args: RfftArgs<'_, f32, Complex32>,
215    ) -> Result<()> {
216        let n = self.desc.n;
217        let batch = self.desc.batch;
218        let in_shape = [batch, n];
219        let out_shape = [batch, n / 2 + 1];
220        if args.x.shape != in_shape {
221            return Err(Error::InvalidProblem(
222                "baracuda-kernels::RfftPlan<f32>: x shape != [batch, n]",
223            ));
224        }
225        if args.y.shape != out_shape {
226            return Err(Error::InvalidProblem(
227                "baracuda-kernels::RfftPlan<f32>: y shape != [batch, n/2 + 1]",
228            ));
229        }
230        let in_numel = (batch as i64) * (n as i64);
231        let out_numel = (batch as i64) * ((n / 2 + 1) as i64);
232        if (args.x.data.len() as i64) < in_numel {
233            return Err(Error::BufferTooSmall {
234                needed: in_numel as usize,
235                got: args.x.data.len(),
236            });
237        }
238        if (args.y.data.len() as i64) < out_numel {
239            return Err(Error::BufferTooSmall {
240                needed: out_numel as usize,
241                got: args.y.data.len(),
242            });
243        }
244        if in_numel == 0 {
245            return Ok(());
246        }
247
248        let handle = self.ensure_handle()?;
249        self.bind_stream(handle, stream)?;
250
251        let idata = args.x.data.as_raw().0 as *mut f32;
252        let odata = args.y.data.as_raw().0 as *mut cufftComplex;
253        let status = unsafe { cufftExecR2C(handle, idata, odata) };
254        if status != 0 {
255            return Err(Error::CutlassInternal(cufft_to_status(status)));
256        }
257        Ok(())
258    }
259}
260
261impl RfftPlan<f64> {
262    /// Run the R2C FFT (double precision).
263    pub fn run(
264        &self,
265        stream: &Stream,
266        _workspace: Workspace<'_>,
267        args: RfftArgs<'_, f64, Complex64>,
268    ) -> Result<()> {
269        let n = self.desc.n;
270        let batch = self.desc.batch;
271        let in_shape = [batch, n];
272        let out_shape = [batch, n / 2 + 1];
273        if args.x.shape != in_shape {
274            return Err(Error::InvalidProblem(
275                "baracuda-kernels::RfftPlan<f64>: x shape != [batch, n]",
276            ));
277        }
278        if args.y.shape != out_shape {
279            return Err(Error::InvalidProblem(
280                "baracuda-kernels::RfftPlan<f64>: y shape != [batch, n/2 + 1]",
281            ));
282        }
283        let in_numel = (batch as i64) * (n as i64);
284        let out_numel = (batch as i64) * ((n / 2 + 1) as i64);
285        if (args.x.data.len() as i64) < in_numel {
286            return Err(Error::BufferTooSmall {
287                needed: in_numel as usize,
288                got: args.x.data.len(),
289            });
290        }
291        if (args.y.data.len() as i64) < out_numel {
292            return Err(Error::BufferTooSmall {
293                needed: out_numel as usize,
294                got: args.y.data.len(),
295            });
296        }
297        if in_numel == 0 {
298            return Ok(());
299        }
300
301        let handle = self.ensure_handle()?;
302        self.bind_stream(handle, stream)?;
303
304        let idata = args.x.data.as_raw().0 as *mut f64;
305        let odata = args.y.data.as_raw().0 as *mut cufftDoubleComplex;
306        let status = unsafe { cufftExecD2Z(handle, idata, odata) };
307        if status != 0 {
308            return Err(Error::CutlassInternal(cufft_to_status(status)));
309        }
310        Ok(())
311    }
312}
313
314impl<T: Element> Drop for RfftPlan<T> {
315    fn drop(&mut self) {
316        let h = self.handle.get();
317        if h != HANDLE_UNINIT {
318            unsafe {
319                let _ = cufftDestroy(h);
320            }
321            self.handle.set(HANDLE_UNINIT);
322        }
323    }
324}
325
326// =============================================================================
327// IRFFT — complex (Hermitian-half) → real
328// =============================================================================
329
330/// Descriptor for an IRFFT (complex-to-real) op.
331///
332/// Note: cuFFT cannot infer the output length `n` from the Hermitian-
333/// half input alone (both `2 * (n/2)` and `2 * (n/2) + 1` produce
334/// inputs of length `n/2 + 1`). The descriptor carries `n` explicitly;
335/// the input shape is then `[batch, n/2 + 1]` and output is
336/// `[batch, n]`.
337#[derive(Copy, Clone, Debug)]
338pub struct IrfftDescriptor {
339    /// Real output length. Input shape is `[batch, n/2 + 1]`.
340    pub n: i32,
341    /// Number of independent transforms in one launch.
342    pub batch: i32,
343    /// Real-side element type (output dtype) — `F32` or `F64`.
344    pub element: ElementKind,
345}
346
/// Args bundle for an IRFFT.
///
/// `T` is the *real* output type; `C` is the matching complex input
/// type ([`Complex32`] for `f32`, [`Complex64`] for `f64`).
pub struct IrfftArgs<'a, T: Element, C: Element> {
    /// Complex input tensor `[batch, n/2 + 1]`.
    ///
    /// NOTE(review): `run` passes this buffer to cuFFT's C2R/Z2D exec as a
    /// `*mut` pointer even though it is only borrowed immutably here. cuFFT
    /// documents that out-of-place complex-to-real transforms may overwrite
    /// their input, so confirm that callers do not rely on `x` being unchanged
    /// after `run`.
    pub x: TensorRef<'a, C, 2>,
    /// Real output tensor `[batch, n]`; `run` writes the `1/n`-scaled result here.
    pub y: TensorMut<'a, T, 2>,
}
357
358/// 1-D IRFFT plan — Hermitian-half complex input → real output.
359///
360/// Wraps cuFFT's `cufftExecC2R` (`f32`) / `cufftExecZ2D` (`f64`),
361/// followed by a `scale_inplace_real_*` launch that applies the
362/// `1/n` normalization (PyTorch `norm="backward"`).
363///
364/// **When to use**: inverse FFT producing real-valued data. The
365/// caller supplies `n` explicitly because the Hermitian-half input
366/// length is ambiguous between `2 * (n/2)` and `2 * (n/2) + 1`.
367///
368/// **Dtypes**: `Complex32` → `f32`; `Complex64` → `f64`.
369///
370/// **Shape**: complex `[batch, n/2 + 1]` → real `[batch, n]`.
371///
372/// **Normalization**: normalized by `1/n` to match PyTorch's
373/// `norm="backward"`.
374///
375/// **Workspace**: zero — cuFFT manages internal workspace.
376///
377/// **Precision guarantee**: deterministic; not bit-stable across
378/// cuFFT versions.
379///
380/// Owns a lazy cuFFT handle (`!Sync` / `!Send`); destroyed on `Drop`.
381pub struct IrfftPlan<T: Element> {
382    desc: IrfftDescriptor,
383    sku: KernelSku,
384    handle: Cell<cufftHandle>,
385    _marker: PhantomData<T>,
386}
387
388impl<T: Element> IrfftPlan<T> {
389    /// Pick a kernel + validate the descriptor.
390    pub fn select(
391        _stream: &Stream,
392        desc: &IrfftDescriptor,
393        _pref: PlanPreference,
394    ) -> Result<Self> {
395        if desc.element != T::KIND {
396            return Err(Error::Unsupported(
397                "baracuda-kernels::IrfftPlan: descriptor.element != T::KIND",
398            ));
399        }
400        if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
401            return Err(Error::Unsupported(
402                "baracuda-kernels::IrfftPlan: C2R FFT supports f32 + f64 only",
403            ));
404        }
405        if desc.n <= 0 {
406            return Err(Error::InvalidProblem(
407                "baracuda-kernels::IrfftPlan: n must be > 0",
408            ));
409        }
410        if desc.batch <= 0 {
411            return Err(Error::InvalidProblem(
412                "baracuda-kernels::IrfftPlan: batch must be > 0",
413            ));
414        }
415        let math_precision = match T::KIND {
416            ElementKind::F64 => MathPrecision::F64,
417            _ => MathPrecision::F32,
418        };
419        let aux = match T::KIND {
420            ElementKind::F32 => Some(ElementKind::Complex32),
421            ElementKind::F64 => Some(ElementKind::Complex64),
422            _ => None,
423        };
424        let precision_guarantee = PrecisionGuarantee {
425            math_precision,
426            accumulator: T::KIND,
427            bit_stable_on_same_hardware: false,
428            deterministic: true,
429        };
430        let sku = KernelSku {
431            category: OpCategory::Fft,
432            op: FftKind::Irfft as u16,
433            element: T::KIND,
434            aux_element: aux,
435            layout: None,
436            epilogue: None,
437            arch: ArchSku::Sm80,
438            backend: BackendKind::Cufft,
439            precision_guarantee,
440        };
441        Ok(Self {
442            desc: *desc,
443            sku,
444            handle: Cell::new(HANDLE_UNINIT),
445            _marker: PhantomData,
446        })
447    }
448
449    /// Kernel SKU identity.
450    #[inline]
451    pub fn sku(&self) -> KernelSku {
452        self.sku
453    }
454
455    /// Numerical guarantees.
456    #[inline]
457    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
458        self.sku.precision_guarantee
459    }
460
461    /// Workspace size in bytes.
462    #[inline]
463    pub fn workspace_size(&self) -> usize {
464        0
465    }
466
467    fn ensure_handle(&self) -> Result<cufftHandle> {
468        let h = self.handle.get();
469        if h != HANDLE_UNINIT {
470            return Ok(h);
471        }
472        let fft_type = match T::KIND {
473            ElementKind::F32 => CUFFT_C2R,
474            ElementKind::F64 => CUFFT_Z2D,
475            _ => unreachable!("select() gates on F32 / F64"),
476        };
477        let mut handle: cufftHandle = HANDLE_UNINIT;
478        let status = unsafe {
479            cufftPlan1d(
480                &mut handle as *mut _,
481                self.desc.n,
482                fft_type,
483                self.desc.batch,
484            )
485        };
486        if status != 0 {
487            return Err(Error::CutlassInternal(cufft_to_status(status)));
488        }
489        self.handle.set(handle);
490        Ok(handle)
491    }
492
493    fn bind_stream(&self, handle: cufftHandle, stream: &Stream) -> Result<()> {
494        let stream_ptr = stream.as_raw() as *mut c_void;
495        let status = unsafe { cufftSetStream(handle, stream_ptr) };
496        if status != 0 {
497            return Err(Error::CutlassInternal(cufft_to_status(status)));
498        }
499        Ok(())
500    }
501}
502
503impl IrfftPlan<f32> {
504    /// Run the C2R FFT (single precision). Applies `1/n` normalization
505    /// to the output to match PyTorch's `norm="backward"`.
506    pub fn run(
507        &self,
508        stream: &Stream,
509        _workspace: Workspace<'_>,
510        args: IrfftArgs<'_, f32, Complex32>,
511    ) -> Result<()> {
512        let n = self.desc.n;
513        let batch = self.desc.batch;
514        let in_shape = [batch, n / 2 + 1];
515        let out_shape = [batch, n];
516        if args.x.shape != in_shape {
517            return Err(Error::InvalidProblem(
518                "baracuda-kernels::IrfftPlan<f32>: x shape != [batch, n/2 + 1]",
519            ));
520        }
521        if args.y.shape != out_shape {
522            return Err(Error::InvalidProblem(
523                "baracuda-kernels::IrfftPlan<f32>: y shape != [batch, n]",
524            ));
525        }
526        let in_numel = (batch as i64) * ((n / 2 + 1) as i64);
527        let out_numel = (batch as i64) * (n as i64);
528        if (args.x.data.len() as i64) < in_numel {
529            return Err(Error::BufferTooSmall {
530                needed: in_numel as usize,
531                got: args.x.data.len(),
532            });
533        }
534        if (args.y.data.len() as i64) < out_numel {
535            return Err(Error::BufferTooSmall {
536                needed: out_numel as usize,
537                got: args.y.data.len(),
538            });
539        }
540        if out_numel == 0 {
541            return Ok(());
542        }
543
544        let handle = self.ensure_handle()?;
545        self.bind_stream(handle, stream)?;
546
547        let idata = args.x.data.as_raw().0 as *mut cufftComplex;
548        let odata = args.y.data.as_raw().0 as *mut f32;
549        let status = unsafe { cufftExecC2R(handle, idata, odata) };
550        if status != 0 {
551            return Err(Error::CutlassInternal(cufft_to_status(status)));
552        }
553
554        // Apply 1/n normalization.
555        let scale = 1.0_f32 / (n as f32);
556        let stream_ptr = stream.as_raw() as *mut c_void;
557        let s = unsafe {
558            baracuda_kernels_scale_inplace_real_f32_run(
559                out_numel,
560                scale,
561                odata as *mut c_void,
562                core::ptr::null_mut(),
563                0,
564                stream_ptr,
565            )
566        };
567        map_status(s)
568    }
569}
570
571impl IrfftPlan<f64> {
572    /// Run the C2R FFT (double precision).
573    pub fn run(
574        &self,
575        stream: &Stream,
576        _workspace: Workspace<'_>,
577        args: IrfftArgs<'_, f64, Complex64>,
578    ) -> Result<()> {
579        let n = self.desc.n;
580        let batch = self.desc.batch;
581        let in_shape = [batch, n / 2 + 1];
582        let out_shape = [batch, n];
583        if args.x.shape != in_shape {
584            return Err(Error::InvalidProblem(
585                "baracuda-kernels::IrfftPlan<f64>: x shape != [batch, n/2 + 1]",
586            ));
587        }
588        if args.y.shape != out_shape {
589            return Err(Error::InvalidProblem(
590                "baracuda-kernels::IrfftPlan<f64>: y shape != [batch, n]",
591            ));
592        }
593        let in_numel = (batch as i64) * ((n / 2 + 1) as i64);
594        let out_numel = (batch as i64) * (n as i64);
595        if (args.x.data.len() as i64) < in_numel {
596            return Err(Error::BufferTooSmall {
597                needed: in_numel as usize,
598                got: args.x.data.len(),
599            });
600        }
601        if (args.y.data.len() as i64) < out_numel {
602            return Err(Error::BufferTooSmall {
603                needed: out_numel as usize,
604                got: args.y.data.len(),
605            });
606        }
607        if out_numel == 0 {
608            return Ok(());
609        }
610
611        let handle = self.ensure_handle()?;
612        self.bind_stream(handle, stream)?;
613
614        let idata = args.x.data.as_raw().0 as *mut cufftDoubleComplex;
615        let odata = args.y.data.as_raw().0 as *mut f64;
616        let status = unsafe { cufftExecZ2D(handle, idata, odata) };
617        if status != 0 {
618            return Err(Error::CutlassInternal(cufft_to_status(status)));
619        }
620
621        let scale = 1.0_f64 / (n as f64);
622        let stream_ptr = stream.as_raw() as *mut c_void;
623        let s = unsafe {
624            baracuda_kernels_scale_inplace_real_f64_run(
625                out_numel,
626                scale,
627                odata as *mut c_void,
628                core::ptr::null_mut(),
629                0,
630                stream_ptr,
631            )
632        };
633        map_status(s)
634    }
635}
636
637impl<T: Element> Drop for IrfftPlan<T> {
638    fn drop(&mut self) {
639        let h = self.handle.get();
640        if h != HANDLE_UNINIT {
641            unsafe {
642                let _ = cufftDestroy(h);
643            }
644            self.handle.set(HANDLE_UNINIT);
645        }
646    }
647}