Skip to main content

baracuda_cufft/
lib.rs

//! Safe Rust wrappers for NVIDIA cuFFT.
//!
//! Covers the one-shot `cufftPlan1d`/`cufftPlan2d`/`cufftPlan3d` plans
//! ([`Plan1d`], [`Plan2d`], [`Plan3d`]), batched `cufftPlanMany` plans
//! ([`PlanMany`]), the two-step `cufftCreate` + `cufftMakePlan*` flow
//! ([`Plan`]), and single- and double-precision exec entry points.
//! Multi-GPU (`cufftXt`) support and load/store callbacks are exposed as
//! thin, mostly `unsafe` helpers in the [`xt`] and [`callback`] modules.
//!
//! ```no_run
//! use baracuda_driver::{Context, Device, DeviceBuffer};
//! use baracuda_cufft::{Plan1d, Transform};
//!
//! # fn demo() -> Result<(), Box<dyn std::error::Error>> {
//! let device = Device::get(0)?;
//! let ctx = Context::new(&device)?;
//! let host: Vec<f32> = (0..1024).map(|i| (i as f32 * 0.05).sin()).collect();
//! let mut d_in = DeviceBuffer::from_slice(&ctx, &host)?;
//! let mut d_out: DeviceBuffer<baracuda_types::Complex32> =
//!     DeviceBuffer::new(&ctx, host.len() / 2 + 1)?;
//!
//! let plan = Plan1d::new(host.len() as i32, Transform::R2C, 1)?;
//! plan.exec_r2c(&mut d_in, &mut d_out)?;
//! # Ok(()) }
//! ```
23
24#![warn(missing_debug_implementations)]
25
26use baracuda_cufft_sys::{
27    cufft, cufftComplex, cufftDoubleComplex, cufftHandle, cufftResult, cufftType,
28};
29use baracuda_driver::{DeviceBuffer, Stream};
30use baracuda_types::{Complex32, Complex64};
31
32/// Error type for cuFFT operations.
33pub type Error = baracuda_core::Error<cufftResult>;
34/// Result alias.
35pub type Result<T, E = Error> = core::result::Result<T, E>;
36
37#[inline]
38fn check(status: cufftResult) -> Result<()> {
39    Error::check(status)
40}
41
/// Kind of transform a plan is built for.
///
/// The first three variants work on single-precision (`f32`) data, the
/// last three on double-precision (`f64`) data.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Transform {
    /// Single-precision real → complex (forward).
    R2C,
    /// Single-precision complex → real (inverse).
    C2R,
    /// Single-precision complex → complex; the direction is chosen at exec time.
    C2C,
    /// Double-precision real → complex (forward).
    D2Z,
    /// Double-precision complex → real (inverse).
    Z2D,
    /// Double-precision complex → complex; the direction is chosen at exec time.
    Z2Z,
}
58
59impl Transform {
60    fn raw(self) -> cufftType {
61        match self {
62            Transform::R2C => cufftType::R2C,
63            Transform::C2R => cufftType::C2R,
64            Transform::C2C => cufftType::C2C,
65            Transform::D2Z => cufftType::D2Z,
66            Transform::Z2D => cufftType::Z2D,
67            Transform::Z2Z => cufftType::Z2Z,
68        }
69    }
70}
71
/// Direction of a complex-to-complex transform (`C2C` / `Z2Z`).
///
/// Defaults to [`Direction::Forward`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Direction {
    /// Forward transform (`exp(-2πi…)` kernel).
    #[default]
    Forward,
    /// Inverse transform (`exp(+2πi…)` kernel); the result is not normalised.
    Inverse,
}
81
82impl Direction {
83    fn raw(self) -> core::ffi::c_int {
84        match self {
85            Direction::Forward => baracuda_cufft_sys::CUFFT_FORWARD,
86            Direction::Inverse => baracuda_cufft_sys::CUFFT_INVERSE,
87        }
88    }
89}
90
91/// A 1-D cuFFT plan.
92pub struct Plan1d {
93    handle: cufftHandle,
94}
95
96unsafe impl Send for Plan1d {}
97
98impl core::fmt::Debug for Plan1d {
99    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100        f.debug_struct("Plan1d")
101            .field("handle", &self.handle)
102            .finish()
103    }
104}
105
106impl Plan1d {
107    /// Create a 1-D plan of length `nx` and `batch` parallel transforms.
108    ///
109    /// # Example
110    ///
111    /// A single forward R2C transform of length 1024.
112    ///
113    /// ```no_run
114    /// use baracuda_driver::{Context, Device, DeviceBuffer};
115    /// use baracuda_cufft::{Plan1d, Transform};
116    /// use baracuda_types::Complex32;
117    ///
118    /// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
119    /// let ctx = Context::new(&Device::get(0)?)?;
120    /// let n = 1024;
121    ///
122    /// let mut input:  DeviceBuffer<f32>       = DeviceBuffer::zeros(&ctx, n)?;
123    /// let mut output: DeviceBuffer<Complex32> = DeviceBuffer::new(&ctx, n / 2 + 1)?;
124    ///
125    /// let plan = Plan1d::new(n as i32, Transform::R2C, 1)?;
126    /// plan.exec_r2c(&mut input, &mut output)?;
127    /// # Ok(()) }
128    /// ```
129    pub fn new(nx: i32, transform: Transform, batch: i32) -> Result<Self> {
130        let c = cufft()?;
131        let cu = c.cufft_plan_1d()?;
132        let mut plan: cufftHandle = 0;
133        check(unsafe { cu(&mut plan, nx, transform.raw(), batch) })?;
134        Ok(Self { handle: plan })
135    }
136
137    /// Bind subsequent exec calls on this plan to `stream`.
138    pub fn set_stream(&self, stream: &Stream) -> Result<()> {
139        let c = cufft()?;
140        let cu = c.cufft_set_stream()?;
141        check(unsafe { cu(self.handle, stream.as_raw() as _) })
142    }
143
144    /// Execute a real-to-complex transform.
145    pub fn exec_r2c(
146        &self,
147        input: &mut DeviceBuffer<f32>,
148        output: &mut DeviceBuffer<Complex32>,
149    ) -> Result<()> {
150        let c = cufft()?;
151        let cu = c.cufft_exec_r2c()?;
152        check(unsafe {
153            cu(
154                self.handle,
155                input.as_raw().0 as *mut f32,
156                output.as_raw().0 as *mut cufftComplex,
157            )
158        })
159    }
160
161    /// Execute a complex-to-real transform.
162    ///
163    /// Plan must have been built with [`Transform::C2R`]. cuFFT inverse R2C
164    /// transforms are unnormalised — divide by `n` to recover the original
165    /// signal.
166    ///
167    /// # Example
168    ///
169    /// Round-trip a length-1024 real signal through R2C then C2R.
170    ///
171    /// ```no_run
172    /// use baracuda_driver::{Context, Device, DeviceBuffer};
173    /// use baracuda_cufft::{Plan1d, Transform};
174    /// use baracuda_types::Complex32;
175    ///
176    /// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
177    /// let ctx = Context::new(&Device::get(0)?)?;
178    /// let n = 1024;
179    ///
180    /// let mut signal: DeviceBuffer<f32>       = DeviceBuffer::zeros(&ctx, n)?;
181    /// let mut spectrum: DeviceBuffer<Complex32> = DeviceBuffer::new(&ctx, n / 2 + 1)?;
182    /// let mut recovered: DeviceBuffer<f32>    = DeviceBuffer::zeros(&ctx, n)?;
183    ///
184    /// let fwd = Plan1d::new(n as i32, Transform::R2C, 1)?;
185    /// let inv = Plan1d::new(n as i32, Transform::C2R, 1)?;
186    /// fwd.exec_r2c(&mut signal, &mut spectrum)?;
187    /// inv.exec_c2r(&mut spectrum, &mut recovered)?;
188    /// // `recovered` now holds signal * n; divide by n for the original.
189    /// # Ok(()) }
190    /// ```
191    pub fn exec_c2r(
192        &self,
193        input: &mut DeviceBuffer<Complex32>,
194        output: &mut DeviceBuffer<f32>,
195    ) -> Result<()> {
196        let c = cufft()?;
197        let cu = c.cufft_exec_c2r()?;
198        check(unsafe {
199            cu(
200                self.handle,
201                input.as_raw().0 as *mut cufftComplex,
202                output.as_raw().0 as *mut f32,
203            )
204        })
205    }
206
207    /// Execute a complex-to-complex transform in the given direction.
208    pub fn exec_c2c(
209        &self,
210        input: &mut DeviceBuffer<Complex32>,
211        output: &mut DeviceBuffer<Complex32>,
212        direction: Direction,
213    ) -> Result<()> {
214        let c = cufft()?;
215        let cu = c.cufft_exec_c2c()?;
216        check(unsafe {
217            cu(
218                self.handle,
219                input.as_raw().0 as *mut cufftComplex,
220                output.as_raw().0 as *mut cufftComplex,
221                direction.raw(),
222            )
223        })
224    }
225
226    /// Raw `cufftHandle`. Use with care.
227    #[inline]
228    pub fn as_raw(&self) -> cufftHandle {
229        self.handle
230    }
231}
232
233impl Drop for Plan1d {
234    fn drop(&mut self) {
235        if let Ok(c) = cufft() {
236            if let Ok(cu) = c.cufft_destroy() {
237                let _ = unsafe { cu(self.handle) };
238            }
239        }
240    }
241}
242
243/// A 2-D cuFFT plan.
244pub struct Plan2d {
245    handle: cufftHandle,
246}
247
248unsafe impl Send for Plan2d {}
249
250impl core::fmt::Debug for Plan2d {
251    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
252        f.debug_struct("Plan2d")
253            .field("handle", &self.handle)
254            .finish()
255    }
256}
257
258impl Plan2d {
259    /// Create a 2-D plan of dimensions `nx × ny`.
260    ///
261    /// # Example
262    ///
263    /// Forward 2-D C2C FFT of a 128×128 complex image.
264    ///
265    /// ```no_run
266    /// use baracuda_driver::{Context, Device, DeviceBuffer};
267    /// use baracuda_cufft::{Direction, Plan2d, Transform};
268    /// use baracuda_types::Complex32;
269    ///
270    /// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
271    /// let ctx = Context::new(&Device::get(0)?)?;
272    /// let (nx, ny) = (128, 128);
273    ///
274    /// let mut img:      DeviceBuffer<Complex32> = DeviceBuffer::new(&ctx, (nx * ny) as usize)?;
275    /// let mut spectrum: DeviceBuffer<Complex32> = DeviceBuffer::new(&ctx, (nx * ny) as usize)?;
276    ///
277    /// let plan = Plan2d::new(nx, ny, Transform::C2C)?;
278    /// plan.exec_c2c(&mut img, &mut spectrum, Direction::Forward)?;
279    /// # Ok(()) }
280    /// ```
281    pub fn new(nx: i32, ny: i32, transform: Transform) -> Result<Self> {
282        let c = cufft()?;
283        let cu = c.cufft_plan_2d()?;
284        let mut plan: cufftHandle = 0;
285        check(unsafe { cu(&mut plan, nx, ny, transform.raw()) })?;
286        Ok(Self { handle: plan })
287    }
288
289    /// Execute a complex-to-complex 2D transform.
290    pub fn exec_c2c(
291        &self,
292        input: &mut DeviceBuffer<Complex32>,
293        output: &mut DeviceBuffer<Complex32>,
294        direction: Direction,
295    ) -> Result<()> {
296        let c = cufft()?;
297        let cu = c.cufft_exec_c2c()?;
298        check(unsafe {
299            cu(
300                self.handle,
301                input.as_raw().0 as *mut cufftComplex,
302                output.as_raw().0 as *mut cufftComplex,
303                direction.raw(),
304            )
305        })
306    }
307
308    /// Raw handle.
309    #[inline]
310    pub fn as_raw(&self) -> cufftHandle {
311        self.handle
312    }
313}
314
315impl Drop for Plan2d {
316    fn drop(&mut self) {
317        if let Ok(c) = cufft() {
318            if let Ok(cu) = c.cufft_destroy() {
319                let _ = unsafe { cu(self.handle) };
320            }
321        }
322    }
323}
324
325/// Owned 3-D cuFFT plan.
326pub struct Plan3d {
327    handle: cufftHandle,
328}
329
330unsafe impl Send for Plan3d {}
331
332impl core::fmt::Debug for Plan3d {
333    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
334        f.debug_struct("Plan3d")
335            .field("handle", &self.handle)
336            .finish()
337    }
338}
339
340impl Plan3d {
341    /// Create a 3-D plan of dimensions `nx × ny × nz`.
342    pub fn new(nx: i32, ny: i32, nz: i32, transform: Transform) -> Result<Self> {
343        let c = cufft()?;
344        let cu = c.cufft_plan_3d()?;
345        let mut plan: cufftHandle = 0;
346        check(unsafe { cu(&mut plan, nx, ny, nz, transform.raw()) })?;
347        Ok(Self { handle: plan })
348    }
349
350    /// Bind subsequent exec calls on this plan to `stream`.
351    pub fn set_stream(&self, stream: &Stream) -> Result<()> {
352        let c = cufft()?;
353        let cu = c.cufft_set_stream()?;
354        check(unsafe { cu(self.handle, stream.as_raw() as _) })
355    }
356
357    /// Execute a 3-D complex-to-complex transform in the given direction.
358    pub fn exec_c2c(
359        &self,
360        input: &mut DeviceBuffer<Complex32>,
361        output: &mut DeviceBuffer<Complex32>,
362        direction: Direction,
363    ) -> Result<()> {
364        let c = cufft()?;
365        let cu = c.cufft_exec_c2c()?;
366        check(unsafe {
367            cu(
368                self.handle,
369                input.as_raw().0 as *mut cufftComplex,
370                output.as_raw().0 as *mut cufftComplex,
371                direction.raw(),
372            )
373        })
374    }
375
376    /// Raw `cufftHandle`. Use with care.
377    #[inline]
378    pub fn as_raw(&self) -> cufftHandle {
379        self.handle
380    }
381}
382
383impl Drop for Plan3d {
384    fn drop(&mut self) {
385        if let Ok(c) = cufft() {
386            if let Ok(cu) = c.cufft_destroy() {
387                let _ = unsafe { cu(self.handle) };
388            }
389        }
390    }
391}
392
393/// cuFFT library version, e.g. `11300` for cuFFT 11.3.0.
394pub fn version() -> Result<i32> {
395    let c = cufft()?;
396    let cu = c.cufft_get_version()?;
397    let mut v: core::ffi::c_int = 0;
398    check(unsafe { cu(&mut v) })?;
399    Ok(v)
400}
401
402// =======================================================================
403// Double-precision exec + PlanMany (batched) + XT multi-GPU
404// =======================================================================
405
/// Generates the double-precision exec methods (`exec_d2z`, `exec_z2d`,
/// `exec_z2z`) for a plan type that stores its handle in a `handle` field.
macro_rules! exec_z_impls {
    ($plan:ty) => {
        impl $plan {
            /// Run a double-precision real-to-complex (D → Z) transform.
            /// The plan must have been built with `Transform::D2Z`.
            pub fn exec_d2z(
                &self,
                input: &mut DeviceBuffer<f64>,
                output: &mut DeviceBuffer<Complex64>,
            ) -> Result<()> {
                let lib = cufft()?;
                let exec = lib.cufft_exec_d2z()?;
                let src = input.as_raw().0 as *mut f64;
                let dst = output.as_raw().0 as *mut cufftDoubleComplex;
                // SAFETY: both pointers come from device buffers that stay
                // mutably borrowed for the whole call.
                let status = unsafe { exec(self.handle, src, dst) };
                check(status)
            }

            /// Run a double-precision complex-to-real (Z → D) transform.
            pub fn exec_z2d(
                &self,
                input: &mut DeviceBuffer<Complex64>,
                output: &mut DeviceBuffer<f64>,
            ) -> Result<()> {
                let lib = cufft()?;
                let exec = lib.cufft_exec_z2d()?;
                let src = input.as_raw().0 as *mut cufftDoubleComplex;
                let dst = output.as_raw().0 as *mut f64;
                // SAFETY: both pointers come from device buffers that stay
                // mutably borrowed for the whole call.
                let status = unsafe { exec(self.handle, src, dst) };
                check(status)
            }

            /// Run a double-precision complex-to-complex (Z → Z) transform
            /// in `direction`.
            pub fn exec_z2z(
                &self,
                input: &mut DeviceBuffer<Complex64>,
                output: &mut DeviceBuffer<Complex64>,
                direction: Direction,
            ) -> Result<()> {
                let lib = cufft()?;
                let exec = lib.cufft_exec_z2z()?;
                let src = input.as_raw().0 as *mut cufftDoubleComplex;
                let dst = output.as_raw().0 as *mut cufftDoubleComplex;
                // SAFETY: both pointers come from device buffers that stay
                // mutably borrowed for the whole call.
                let status = unsafe { exec(self.handle, src, dst, direction.raw()) };
                check(status)
            }
        }
    };
}
465
466exec_z_impls!(Plan1d);
467exec_z_impls!(Plan2d);
468
469/// A batched / many-rank plan (`cufftPlanMany`). Handles arbitrary
470/// rank + advanced-data-layout transforms.
471#[derive(Debug)]
472pub struct PlanMany {
473    handle: cufftHandle,
474}
475
476impl PlanMany {
477    /// Construct a batched plan. `n[rank]` is the transform shape;
478    /// `inembed` / `onembed` are the actual memory layouts of in/out
479    /// (pass `None` for packed). `istride`/`ostride` are element strides
480    /// between successive elements; `idist`/`odist` are element strides
481    /// between successive batches.
482    ///
483    /// # Example
484    ///
485    /// 32 packed 1-D R2C transforms of length 256 (e.g., a STFT frame).
486    ///
487    /// ```no_run
488    /// use baracuda_driver::{Context, Device, DeviceBuffer};
489    /// use baracuda_cufft::{PlanMany, Transform};
490    /// use baracuda_types::Complex32;
491    ///
492    /// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
493    /// let ctx = Context::new(&Device::get(0)?)?;
494    /// let n_per = 256i32;
495    /// let batch = 32i32;
496    /// let mut n = [n_per];
497    ///
498    /// // Packed contiguous layout: pass None for embeds, strides = 1, dist = transform length.
499    /// let plan = PlanMany::new(
500    ///     /* rank   */ 1,
501    ///     /* n      */ &mut n,
502    ///     /* inemb  */ None,
503    ///     /* istr   */ 1,
504    ///     /* idist  */ n_per,
505    ///     /* onemb  */ None,
506    ///     /* ostr   */ 1,
507    ///     /* odist  */ n_per / 2 + 1,
508    ///     /* type   */ Transform::R2C,
509    ///     /* batch  */ batch,
510    /// )?;
511    ///
512    /// let mut input:  DeviceBuffer<f32>       =
513    ///     DeviceBuffer::zeros(&ctx, (n_per * batch) as usize)?;
514    /// let mut output: DeviceBuffer<Complex32> =
515    ///     DeviceBuffer::new(&ctx, ((n_per / 2 + 1) * batch) as usize)?;
516    /// plan.exec_r2c(&mut input, &mut output)?;
517    /// # Ok(()) }
518    /// ```
519    #[allow(clippy::too_many_arguments)]
520    pub fn new(
521        rank: i32,
522        n: &mut [i32],
523        inembed: Option<&mut [i32]>,
524        istride: i32,
525        idist: i32,
526        onembed: Option<&mut [i32]>,
527        ostride: i32,
528        odist: i32,
529        ty: Transform,
530        batch: i32,
531    ) -> Result<Self> {
532        let c = cufft()?;
533        let cu = c.cufft_plan_many()?;
534        let mut h: cufftHandle = 0;
535        check(unsafe {
536            cu(
537                &mut h,
538                rank,
539                n.as_mut_ptr(),
540                inembed.map_or(core::ptr::null_mut(), |s| s.as_mut_ptr()),
541                istride,
542                idist,
543                onembed.map_or(core::ptr::null_mut(), |s| s.as_mut_ptr()),
544                ostride,
545                odist,
546                ty.raw(),
547                batch,
548            )
549        })?;
550        Ok(Self { handle: h })
551    }
552
553    /// Raw `cufftHandle`. Use with care.
554    #[inline]
555    pub fn as_raw(&self) -> cufftHandle {
556        self.handle
557    }
558
559    /// Bind the plan to a CUDA stream.
560    pub fn set_stream(&self, stream: &Stream) -> Result<()> {
561        let c = cufft()?;
562        let cu = c.cufft_set_stream()?;
563        check(unsafe { cu(self.handle, stream.as_raw() as _) })
564    }
565}
566
567impl Drop for PlanMany {
568    fn drop(&mut self) {
569        if let Ok(c) = cufft() {
570            if let Ok(cu) = c.cufft_destroy() {
571                let _ = unsafe { cu(self.handle) };
572            }
573        }
574    }
575}
576
577exec_z_impls!(PlanMany);
578
579impl PlanMany {
580    /// Execute R → C (single-precision R2C).
581    pub fn exec_r2c(
582        &self,
583        input: &mut DeviceBuffer<f32>,
584        output: &mut DeviceBuffer<Complex32>,
585    ) -> Result<()> {
586        let c = cufft()?;
587        let cu = c.cufft_exec_r2c()?;
588        check(unsafe {
589            cu(
590                self.handle,
591                input.as_raw().0 as *mut f32,
592                output.as_raw().0 as *mut cufftComplex,
593            )
594        })
595    }
596
597    /// Execute C → R (single-precision C2R).
598    pub fn exec_c2r(
599        &self,
600        input: &mut DeviceBuffer<Complex32>,
601        output: &mut DeviceBuffer<f32>,
602    ) -> Result<()> {
603        let c = cufft()?;
604        let cu = c.cufft_exec_c2r()?;
605        check(unsafe {
606            cu(
607                self.handle,
608                input.as_raw().0 as *mut cufftComplex,
609                output.as_raw().0 as *mut f32,
610            )
611        })
612    }
613
614    /// Execute C → C.
615    pub fn exec_c2c(
616        &self,
617        input: &mut DeviceBuffer<Complex32>,
618        output: &mut DeviceBuffer<Complex32>,
619        direction: Direction,
620    ) -> Result<()> {
621        let c = cufft()?;
622        let cu = c.cufft_exec_c2c()?;
623        check(unsafe {
624            cu(
625                self.handle,
626                input.as_raw().0 as *mut cufftComplex,
627                output.as_raw().0 as *mut cufftComplex,
628                direction.raw(),
629            )
630        })
631    }
632}
633
634/// Sizing estimates (workspace bytes) for a plan shape.
635pub fn estimate_1d(nx: i32, ty: Transform, batch: i32) -> Result<usize> {
636    let c = cufft()?;
637    let cu = c.cufft_estimate_1d()?;
638    let mut s: usize = 0;
639    check(unsafe { cu(nx, ty.raw(), batch, &mut s) })?;
640    Ok(s)
641}
642
643/// 2-D workspace-size estimate (bytes) for a plan of the given shape.
644/// Wraps `cufftEstimate2d`.
645pub fn estimate_2d(nx: i32, ny: i32, ty: Transform) -> Result<usize> {
646    let c = cufft()?;
647    let cu = c.cufft_estimate_2d()?;
648    let mut s: usize = 0;
649    check(unsafe { cu(nx, ny, ty.raw(), &mut s) })?;
650    Ok(s)
651}
652
653/// 3-D workspace-size estimate (bytes) for a plan of the given shape.
654/// Wraps `cufftEstimate3d`.
655pub fn estimate_3d(nx: i32, ny: i32, nz: i32, ty: Transform) -> Result<usize> {
656    let c = cufft()?;
657    let cu = c.cufft_estimate_3d()?;
658    let mut s: usize = 0;
659    check(unsafe { cu(nx, ny, nz, ty.raw(), &mut s) })?;
660    Ok(s)
661}
662
663/// Multi-GPU (XT) extension helpers. Use these to distribute a cuFFT
664/// plan across multiple GPUs via `cufftXtSetGPUs` + `cufftXtExec`.
665pub mod xt {
666    use super::*;
667
668    /// Spread a plan across `which_gpus` (CUDA device ordinals).
669    ///
670    /// # Safety
671    ///
672    /// `plan` must be a fresh (unexecuted) handle; all ordinals in
673    /// `which_gpus` must be live CUDA devices.
674    pub unsafe fn set_gpus(plan: cufftHandle, which_gpus: &mut [i32]) -> Result<()> { unsafe {
675        let c = cufft()?;
676        let cu = c.cufft_xt_set_gpus()?;
677        check(cu(plan, which_gpus.len() as i32, which_gpus.as_mut_ptr()))
678    }}
679
680    /// Allocate a multi-GPU `cudaLibXtDesc*` matching the plan.
681    /// Returns an opaque pointer that must be freed with [`free`].
682    ///
683    /// # Safety
684    ///
685    /// `plan` must have been configured with [`set_gpus`] first.
686    pub unsafe fn malloc(
687        plan: cufftHandle,
688        subformat: i32,
689    ) -> Result<*mut core::ffi::c_void> { unsafe {
690        let c = cufft()?;
691        let cu = c.cufft_xt_malloc()?;
692        let mut desc: *mut core::ffi::c_void = core::ptr::null_mut();
693        check(cu(plan, &mut desc, subformat))?;
694        Ok(desc)
695    }}
696
697    /// Free an XT descriptor from [`malloc`].
698    ///
699    /// # Safety
700    ///
701    /// `desc` must come from [`malloc`].
702    pub unsafe fn free(desc: *mut core::ffi::c_void) -> Result<()> { unsafe {
703        let c = cufft()?;
704        let cu = c.cufft_xt_free()?;
705        check(cu(desc))
706    }}
707
708    /// Multi-GPU memcpy between host / device / XT descriptors.
709    ///
710    /// # Safety
711    ///
712    /// Pointer kinds and `ty` must agree.
713    pub unsafe fn memcpy(
714        plan: cufftHandle,
715        dst: *mut core::ffi::c_void,
716        src: *mut core::ffi::c_void,
717        ty: i32,
718    ) -> Result<()> { unsafe {
719        let c = cufft()?;
720        let cu = c.cufft_xt_memcpy()?;
721        check(cu(plan, dst, src, ty))
722    }}
723
724    /// Execute the plan on its XT descriptors.
725    ///
726    /// # Safety
727    ///
728    /// `input` / `output` must be `cudaLibXtDesc*` pointers matching the plan.
729    pub unsafe fn exec_descriptor(
730        plan: cufftHandle,
731        input: *mut core::ffi::c_void,
732        output: *mut core::ffi::c_void,
733        direction: Direction,
734    ) -> Result<()> { unsafe {
735        let c = cufft()?;
736        let cu = c.cufft_xt_exec_descriptor()?;
737        check(cu(plan, input, output, direction.raw()))
738    }}
739}
740
741/// Set a user-allocated scratch work area (`cufftSetWorkArea`).
742///
743/// # Safety
744///
745/// `plan` must have `SetAutoAllocation(false)` first; `work_area` must
746/// be a live device pointer.
747pub unsafe fn set_work_area(plan: cufftHandle, work_area: *mut core::ffi::c_void) -> Result<()> { unsafe {
748    let c = cufft()?;
749    let cu = c.cufft_set_work_area()?;
750    check(cu(plan, work_area))
751}}
752
753/// Disable / re-enable automatic work-area allocation.
754pub fn set_auto_allocation(plan: cufftHandle, auto: bool) -> Result<()> {
755    let c = cufft()?;
756    let cu = c.cufft_set_auto_allocation()?;
757    check(unsafe { cu(plan, if auto { 1 } else { 0 }) })
758}
759
760/// Scratch bytes this plan currently needs.
761pub fn get_size(plan: cufftHandle) -> Result<usize> {
762    let c = cufft()?;
763    let cu = c.cufft_get_size()?;
764    let mut s: usize = 0;
765    check(unsafe { cu(plan, &mut s) })?;
766    Ok(s)
767}
768
769// ============================================================================
770// Two-step plan creation: cufftCreate + cufftMakePlan*
771// ============================================================================
772
773/// Generic cuFFT plan that supports the modern two-step creation flow:
774/// [`Plan::create`] gets you a fresh handle, then one of the
775/// `make_plan_*` methods configures the rank and reports the workspace
776/// size — useful when you want to allocate the work area yourself
777/// (call [`set_auto_allocation(false)`] before `make_plan_*`).
778///
779/// Unlike [`Plan1d`] / [`Plan2d`] / [`Plan3d`], this type can hold any
780/// rank, including the batched-many / 64-bit-many forms.
781pub struct Plan {
782    handle: cufftHandle,
783}
784
785unsafe impl Send for Plan {}
786
787impl core::fmt::Debug for Plan {
788    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
789        f.debug_struct("Plan")
790            .field("handle", &self.handle)
791            .finish()
792    }
793}
794
795impl Plan {
796    /// Allocate a fresh empty plan. The plan is unusable until you
797    /// finalize it with one of the `make_plan_*` methods.
798    pub fn create() -> Result<Self> {
799        let c = cufft()?;
800        let cu = c.cufft_create()?;
801        let mut plan: cufftHandle = 0;
802        check(unsafe { cu(&mut plan) })?;
803        Ok(Self { handle: plan })
804    }
805
806    /// Finalize as a 1-D plan of length `nx` and `batch` parallel
807    /// transforms. Returns the workspace size in bytes.
808    pub fn make_plan_1d(&self, nx: i32, transform: Transform, batch: i32) -> Result<usize> {
809        let c = cufft()?;
810        let cu = c.cufft_make_plan_1d()?;
811        let mut size: usize = 0;
812        check(unsafe { cu(self.handle, nx, transform.raw(), batch, &mut size) })?;
813        Ok(size)
814    }
815
816    /// Finalize as a 2-D plan. Returns workspace size in bytes.
817    pub fn make_plan_2d(&self, nx: i32, ny: i32, transform: Transform) -> Result<usize> {
818        let c = cufft()?;
819        let cu = c.cufft_make_plan_2d()?;
820        let mut size: usize = 0;
821        check(unsafe { cu(self.handle, nx, ny, transform.raw(), &mut size) })?;
822        Ok(size)
823    }
824
825    /// Finalize as a 3-D plan. Returns workspace size in bytes.
826    pub fn make_plan_3d(
827        &self,
828        nx: i32,
829        ny: i32,
830        nz: i32,
831        transform: Transform,
832    ) -> Result<usize> {
833        let c = cufft()?;
834        let cu = c.cufft_make_plan_3d()?;
835        let mut size: usize = 0;
836        check(unsafe { cu(self.handle, nx, ny, nz, transform.raw(), &mut size) })?;
837        Ok(size)
838    }
839
840    /// Finalize as a generic strided/batched plan. Returns workspace
841    /// size in bytes.
842    ///
843    /// # Safety
844    ///
845    /// `n`, `inembed`, `onembed` must be writable arrays of length
846    /// `rank` (cuFFT mutates them in some versions). Pass null for
847    /// `inembed` / `onembed` to use defaults.
848    #[allow(clippy::too_many_arguments)]
849    pub unsafe fn make_plan_many(
850        &self,
851        rank: i32,
852        n: &mut [i32],
853        inembed: *mut i32,
854        istride: i32,
855        idist: i32,
856        onembed: *mut i32,
857        ostride: i32,
858        odist: i32,
859        transform: Transform,
860        batch: i32,
861    ) -> Result<usize> { unsafe {
862        assert_eq!(n.len() as i32, rank, "n.len() must equal rank");
863        let c = cufft()?;
864        let cu = c.cufft_make_plan_many()?;
865        let mut size: usize = 0;
866        check(cu(
867            self.handle,
868            rank,
869            n.as_mut_ptr(),
870            inembed,
871            istride,
872            idist,
873            onembed,
874            ostride,
875            odist,
876            transform.raw(),
877            batch,
878            &mut size,
879        ))?;
880        Ok(size)
881    }}
882
883    /// 64-bit variant of [`make_plan_many`] — use this when any
884    /// dimension or stride exceeds `i32::MAX`.
885    ///
886    /// # Safety
887    ///
888    /// Same as [`make_plan_many`].
889    #[allow(clippy::too_many_arguments)]
890    pub unsafe fn make_plan_many64(
891        &self,
892        rank: i32,
893        n: &mut [i64],
894        inembed: *mut i64,
895        istride: i64,
896        idist: i64,
897        onembed: *mut i64,
898        ostride: i64,
899        odist: i64,
900        transform: Transform,
901        batch: i64,
902    ) -> Result<usize> { unsafe {
903        assert_eq!(n.len() as i32, rank, "n.len() must equal rank");
904        let c = cufft()?;
905        let cu = c.cufft_make_plan_many64()?;
906        let mut size: usize = 0;
907        check(cu(
908            self.handle,
909            rank,
910            n.as_mut_ptr(),
911            inembed,
912            istride,
913            idist,
914            onembed,
915            ostride,
916            odist,
917            transform.raw(),
918            batch,
919            &mut size,
920        ))?;
921        Ok(size)
922    }}
923
924    /// Bind subsequent exec calls to `stream`.
925    pub fn set_stream(&self, stream: &Stream) -> Result<()> {
926        let c = cufft()?;
927        let cu = c.cufft_set_stream()?;
928        check(unsafe { cu(self.handle, stream.as_raw() as _) })
929    }
930
931    /// Raw `cufftHandle`. Use with care.
932    #[inline]
933    pub fn as_raw(&self) -> cufftHandle {
934        self.handle
935    }
936}
937
938impl Drop for Plan {
939    fn drop(&mut self) {
940        if let Ok(c) = cufft() {
941            if let Ok(cu) = c.cufft_destroy() {
942                let _ = unsafe { cu(self.handle) };
943            }
944        }
945    }
946}
947
948// ============================================================================
949// Callback API (cufftXtSetCallback / Clear / SetCallbackSharedSize)
950// ============================================================================
951
952pub mod callback {
953    //! Pre/post callbacks attached to a cuFFT plan via the `cufftXt*`
954    //! callback entry points. The callback ABI is fixed by NVIDIA: each
955    //! callback receives the input/output element index, a caller-info
956    //! pointer, and the data; see the cuFFT reference for the exact
957    //! signatures by callback type. We expose only the raw setters
958    //! because the function-pointer types are PTX-shaped, not regular
959    //! `extern "C"` functions — they ship as device-side `__device__`
960    //! functions linked into the user's CUBIN.
961    use super::*;
962
963    /// Callback type values from the cuFFT header
964    /// (`cufftXtCallbackType`).
965    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
966    #[repr(i32)]
967    pub enum CallbackType {
968        /// Load callback for complex single precision.
969        LoadComplex = 0,
970        /// Load callback for complex double precision.
971        LoadDoubleComplex = 1,
972        /// Load callback for real single precision.
973        LoadReal = 2,
974        /// Load callback for real double precision.
975        LoadDoubleReal = 3,
976        /// Store callback for complex single precision.
977        StoreComplex = 4,
978        /// Store callback for complex double precision.
979        StoreDoubleComplex = 5,
980        /// Store callback for real single precision.
981        StoreReal = 6,
982        /// Store callback for real double precision.
983        StoreDoubleReal = 7,
984    }
985
986    /// Attach a load/store callback to the plan. `callback_routine` is
987    /// an array of device function pointers — one per GPU for
988    /// multi-GPU plans, otherwise a single-element array.
989    /// `caller_info` is parallel; pass null for "no caller info".
990    ///
991    /// # Safety
992    ///
993    /// `callback_routine[i]` must be a `__device__` function with the
994    /// signature cuFFT expects for `cb_type` (see the cuFFT reference).
995    /// The routine and `caller_info` must outlive every plan exec call.
996    pub unsafe fn set(
997        plan: cufftHandle,
998        callback_routine: &mut [*mut core::ffi::c_void],
999        cb_type: CallbackType,
1000        caller_info: &mut [*mut core::ffi::c_void],
1001    ) -> Result<()> { unsafe {
1002        assert_eq!(
1003            callback_routine.len(),
1004            caller_info.len(),
1005            "callback_routine and caller_info must have the same length"
1006        );
1007        let c = cufft()?;
1008        let cu = c.cufft_xt_set_callback()?;
1009        check(cu(
1010            plan,
1011            callback_routine.as_mut_ptr(),
1012            cb_type as i32,
1013            caller_info.as_mut_ptr(),
1014        ))
1015    }}
1016
1017    /// Detach any previously set callback of `cb_type`.
1018    pub fn clear(plan: cufftHandle, cb_type: CallbackType) -> Result<()> {
1019        let c = cufft()?;
1020        let cu = c.cufft_xt_clear_callback()?;
1021        check(unsafe { cu(plan, cb_type as i32) })
1022    }
1023
1024    /// Reserve `shared_size` bytes of dynamic shared memory per kernel
1025    /// for the callback. Maximum permissible value is GPU-dependent.
1026    pub fn set_shared_size(
1027        plan: cufftHandle,
1028        cb_type: CallbackType,
1029        shared_size: usize,
1030    ) -> Result<()> {
1031        let c = cufft()?;
1032        let cu = c.cufft_xt_set_callback_shared_size()?;
1033        check(unsafe { cu(plan, cb_type as i32, shared_size) })
1034    }
1035}