baracuda_cufft/lib.rs
1//! Safe Rust wrappers for NVIDIA cuFFT.
2//!
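//! Covers `cufftPlan1d`/`cufftPlan2d`/`cufftPlan3d`, batched `cufftPlanMany`
//! plans, and two-step `cufftCreate` + `cufftMakePlan*` creation, with
//! single-precision (R2C/C2R/C2C) and double-precision (D2Z/Z2D/Z2Z)
//! transforms. Raw multi-GPU (`cufftXt`) helpers live in [`xt`], and the
//! load/store callback setters live in [`callback`].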
6//!
7//! ```no_run
8//! use baracuda_driver::{Context, Device, DeviceBuffer};
9//! use baracuda_cufft::{Plan1d, Transform};
10//!
11//! # fn demo() -> Result<(), Box<dyn std::error::Error>> {
12//! let device = Device::get(0)?;
13//! let ctx = Context::new(&device)?;
14//! let host: Vec<f32> = (0..1024).map(|i| (i as f32 * 0.05).sin()).collect();
15//! let mut d_in = DeviceBuffer::from_slice(&ctx, &host)?;
16//! let mut d_out: DeviceBuffer<baracuda_types::Complex32> =
17//! DeviceBuffer::new(&ctx, host.len() / 2 + 1)?;
18//!
19//! let plan = Plan1d::new(host.len() as i32, Transform::R2C, 1)?;
20//! plan.exec_r2c(&mut d_in, &mut d_out)?;
21//! # Ok(()) }
22//! ```
23
24#![warn(missing_debug_implementations)]
25
26use baracuda_cufft_sys::{
27 cufft, cufftComplex, cufftDoubleComplex, cufftHandle, cufftResult, cufftType,
28};
29use baracuda_driver::{DeviceBuffer, Stream};
30use baracuda_types::{Complex32, Complex64};
31
32/// Error type for cuFFT operations.
33pub type Error = baracuda_core::Error<cufftResult>;
34/// Result alias.
35pub type Result<T, E = Error> = core::result::Result<T, E>;
36
37#[inline]
38fn check(status: cufftResult) -> Result<()> {
39 Error::check(status)
40}
41
/// Kind of transform a plan is built for.
///
/// The first three variants are single precision (`f32`), the last three
/// double precision (`f64`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Transform {
    /// Single-precision real → complex (forward).
    R2C,
    /// Single-precision complex → real (inverse).
    C2R,
    /// Single-precision complex → complex; direction is chosen at exec time.
    C2C,
    /// Double-precision real → complex (forward).
    D2Z,
    /// Double-precision complex → real (inverse).
    Z2D,
    /// Double-precision complex → complex; direction is chosen at exec time.
    Z2Z,
}
58
59impl Transform {
60 fn raw(self) -> cufftType {
61 match self {
62 Transform::R2C => cufftType::R2C,
63 Transform::C2R => cufftType::C2R,
64 Transform::C2C => cufftType::C2C,
65 Transform::D2Z => cufftType::D2Z,
66 Transform::Z2D => cufftType::Z2D,
67 Transform::Z2Z => cufftType::Z2Z,
68 }
69 }
70}
71
/// Direction of a complex-to-complex transform, chosen at exec time.
///
/// Defaults to [`Direction::Forward`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Direction {
    /// Forward transform (`exp(-2πi…)`).
    #[default]
    Forward,
    /// Inverse, i.e. unnormalized backward, transform (`exp(+2πi…)`).
    Inverse,
}
81
82impl Direction {
83 fn raw(self) -> core::ffi::c_int {
84 match self {
85 Direction::Forward => baracuda_cufft_sys::CUFFT_FORWARD,
86 Direction::Inverse => baracuda_cufft_sys::CUFFT_INVERSE,
87 }
88 }
89}
90
91/// A 1-D cuFFT plan.
92pub struct Plan1d {
93 handle: cufftHandle,
94}
95
96unsafe impl Send for Plan1d {}
97
98impl core::fmt::Debug for Plan1d {
99 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100 f.debug_struct("Plan1d")
101 .field("handle", &self.handle)
102 .finish()
103 }
104}
105
106impl Plan1d {
107 /// Create a 1-D plan of length `nx` and `batch` parallel transforms.
108 ///
109 /// # Example
110 ///
111 /// A single forward R2C transform of length 1024.
112 ///
113 /// ```no_run
114 /// use baracuda_driver::{Context, Device, DeviceBuffer};
115 /// use baracuda_cufft::{Plan1d, Transform};
116 /// use baracuda_types::Complex32;
117 ///
118 /// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
119 /// let ctx = Context::new(&Device::get(0)?)?;
120 /// let n = 1024;
121 ///
122 /// let mut input: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, n)?;
123 /// let mut output: DeviceBuffer<Complex32> = DeviceBuffer::new(&ctx, n / 2 + 1)?;
124 ///
125 /// let plan = Plan1d::new(n as i32, Transform::R2C, 1)?;
126 /// plan.exec_r2c(&mut input, &mut output)?;
127 /// # Ok(()) }
128 /// ```
129 pub fn new(nx: i32, transform: Transform, batch: i32) -> Result<Self> {
130 let c = cufft()?;
131 let cu = c.cufft_plan_1d()?;
132 let mut plan: cufftHandle = 0;
133 check(unsafe { cu(&mut plan, nx, transform.raw(), batch) })?;
134 Ok(Self { handle: plan })
135 }
136
137 /// Bind subsequent exec calls on this plan to `stream`.
138 pub fn set_stream(&self, stream: &Stream) -> Result<()> {
139 let c = cufft()?;
140 let cu = c.cufft_set_stream()?;
141 check(unsafe { cu(self.handle, stream.as_raw() as _) })
142 }
143
144 /// Execute a real-to-complex transform.
145 pub fn exec_r2c(
146 &self,
147 input: &mut DeviceBuffer<f32>,
148 output: &mut DeviceBuffer<Complex32>,
149 ) -> Result<()> {
150 let c = cufft()?;
151 let cu = c.cufft_exec_r2c()?;
152 check(unsafe {
153 cu(
154 self.handle,
155 input.as_raw().0 as *mut f32,
156 output.as_raw().0 as *mut cufftComplex,
157 )
158 })
159 }
160
161 /// Execute a complex-to-real transform.
162 ///
163 /// Plan must have been built with [`Transform::C2R`]. cuFFT inverse R2C
164 /// transforms are unnormalised — divide by `n` to recover the original
165 /// signal.
166 ///
167 /// # Example
168 ///
169 /// Round-trip a length-1024 real signal through R2C then C2R.
170 ///
171 /// ```no_run
172 /// use baracuda_driver::{Context, Device, DeviceBuffer};
173 /// use baracuda_cufft::{Plan1d, Transform};
174 /// use baracuda_types::Complex32;
175 ///
176 /// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
177 /// let ctx = Context::new(&Device::get(0)?)?;
178 /// let n = 1024;
179 ///
180 /// let mut signal: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, n)?;
181 /// let mut spectrum: DeviceBuffer<Complex32> = DeviceBuffer::new(&ctx, n / 2 + 1)?;
182 /// let mut recovered: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, n)?;
183 ///
184 /// let fwd = Plan1d::new(n as i32, Transform::R2C, 1)?;
185 /// let inv = Plan1d::new(n as i32, Transform::C2R, 1)?;
186 /// fwd.exec_r2c(&mut signal, &mut spectrum)?;
187 /// inv.exec_c2r(&mut spectrum, &mut recovered)?;
188 /// // `recovered` now holds signal * n; divide by n for the original.
189 /// # Ok(()) }
190 /// ```
191 pub fn exec_c2r(
192 &self,
193 input: &mut DeviceBuffer<Complex32>,
194 output: &mut DeviceBuffer<f32>,
195 ) -> Result<()> {
196 let c = cufft()?;
197 let cu = c.cufft_exec_c2r()?;
198 check(unsafe {
199 cu(
200 self.handle,
201 input.as_raw().0 as *mut cufftComplex,
202 output.as_raw().0 as *mut f32,
203 )
204 })
205 }
206
207 /// Execute a complex-to-complex transform in the given direction.
208 pub fn exec_c2c(
209 &self,
210 input: &mut DeviceBuffer<Complex32>,
211 output: &mut DeviceBuffer<Complex32>,
212 direction: Direction,
213 ) -> Result<()> {
214 let c = cufft()?;
215 let cu = c.cufft_exec_c2c()?;
216 check(unsafe {
217 cu(
218 self.handle,
219 input.as_raw().0 as *mut cufftComplex,
220 output.as_raw().0 as *mut cufftComplex,
221 direction.raw(),
222 )
223 })
224 }
225
226 /// Raw `cufftHandle`. Use with care.
227 #[inline]
228 pub fn as_raw(&self) -> cufftHandle {
229 self.handle
230 }
231}
232
233impl Drop for Plan1d {
234 fn drop(&mut self) {
235 if let Ok(c) = cufft() {
236 if let Ok(cu) = c.cufft_destroy() {
237 let _ = unsafe { cu(self.handle) };
238 }
239 }
240 }
241}
242
243/// A 2-D cuFFT plan.
244pub struct Plan2d {
245 handle: cufftHandle,
246}
247
248unsafe impl Send for Plan2d {}
249
250impl core::fmt::Debug for Plan2d {
251 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
252 f.debug_struct("Plan2d")
253 .field("handle", &self.handle)
254 .finish()
255 }
256}
257
258impl Plan2d {
259 /// Create a 2-D plan of dimensions `nx × ny`.
260 ///
261 /// # Example
262 ///
263 /// Forward 2-D C2C FFT of a 128×128 complex image.
264 ///
265 /// ```no_run
266 /// use baracuda_driver::{Context, Device, DeviceBuffer};
267 /// use baracuda_cufft::{Direction, Plan2d, Transform};
268 /// use baracuda_types::Complex32;
269 ///
270 /// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
271 /// let ctx = Context::new(&Device::get(0)?)?;
272 /// let (nx, ny) = (128, 128);
273 ///
274 /// let mut img: DeviceBuffer<Complex32> = DeviceBuffer::new(&ctx, (nx * ny) as usize)?;
275 /// let mut spectrum: DeviceBuffer<Complex32> = DeviceBuffer::new(&ctx, (nx * ny) as usize)?;
276 ///
277 /// let plan = Plan2d::new(nx, ny, Transform::C2C)?;
278 /// plan.exec_c2c(&mut img, &mut spectrum, Direction::Forward)?;
279 /// # Ok(()) }
280 /// ```
281 pub fn new(nx: i32, ny: i32, transform: Transform) -> Result<Self> {
282 let c = cufft()?;
283 let cu = c.cufft_plan_2d()?;
284 let mut plan: cufftHandle = 0;
285 check(unsafe { cu(&mut plan, nx, ny, transform.raw()) })?;
286 Ok(Self { handle: plan })
287 }
288
289 /// Execute a complex-to-complex 2D transform.
290 pub fn exec_c2c(
291 &self,
292 input: &mut DeviceBuffer<Complex32>,
293 output: &mut DeviceBuffer<Complex32>,
294 direction: Direction,
295 ) -> Result<()> {
296 let c = cufft()?;
297 let cu = c.cufft_exec_c2c()?;
298 check(unsafe {
299 cu(
300 self.handle,
301 input.as_raw().0 as *mut cufftComplex,
302 output.as_raw().0 as *mut cufftComplex,
303 direction.raw(),
304 )
305 })
306 }
307
308 /// Raw handle.
309 #[inline]
310 pub fn as_raw(&self) -> cufftHandle {
311 self.handle
312 }
313}
314
315impl Drop for Plan2d {
316 fn drop(&mut self) {
317 if let Ok(c) = cufft() {
318 if let Ok(cu) = c.cufft_destroy() {
319 let _ = unsafe { cu(self.handle) };
320 }
321 }
322 }
323}
324
325/// Owned 3-D cuFFT plan.
326pub struct Plan3d {
327 handle: cufftHandle,
328}
329
330unsafe impl Send for Plan3d {}
331
332impl core::fmt::Debug for Plan3d {
333 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
334 f.debug_struct("Plan3d")
335 .field("handle", &self.handle)
336 .finish()
337 }
338}
339
340impl Plan3d {
341 /// Create a 3-D plan of dimensions `nx × ny × nz`.
342 pub fn new(nx: i32, ny: i32, nz: i32, transform: Transform) -> Result<Self> {
343 let c = cufft()?;
344 let cu = c.cufft_plan_3d()?;
345 let mut plan: cufftHandle = 0;
346 check(unsafe { cu(&mut plan, nx, ny, nz, transform.raw()) })?;
347 Ok(Self { handle: plan })
348 }
349
350 /// Bind subsequent exec calls on this plan to `stream`.
351 pub fn set_stream(&self, stream: &Stream) -> Result<()> {
352 let c = cufft()?;
353 let cu = c.cufft_set_stream()?;
354 check(unsafe { cu(self.handle, stream.as_raw() as _) })
355 }
356
357 /// Execute a 3-D complex-to-complex transform in the given direction.
358 pub fn exec_c2c(
359 &self,
360 input: &mut DeviceBuffer<Complex32>,
361 output: &mut DeviceBuffer<Complex32>,
362 direction: Direction,
363 ) -> Result<()> {
364 let c = cufft()?;
365 let cu = c.cufft_exec_c2c()?;
366 check(unsafe {
367 cu(
368 self.handle,
369 input.as_raw().0 as *mut cufftComplex,
370 output.as_raw().0 as *mut cufftComplex,
371 direction.raw(),
372 )
373 })
374 }
375
376 /// Raw `cufftHandle`. Use with care.
377 #[inline]
378 pub fn as_raw(&self) -> cufftHandle {
379 self.handle
380 }
381}
382
383impl Drop for Plan3d {
384 fn drop(&mut self) {
385 if let Ok(c) = cufft() {
386 if let Ok(cu) = c.cufft_destroy() {
387 let _ = unsafe { cu(self.handle) };
388 }
389 }
390 }
391}
392
393/// cuFFT library version, e.g. `11300` for cuFFT 11.3.0.
394pub fn version() -> Result<i32> {
395 let c = cufft()?;
396 let cu = c.cufft_get_version()?;
397 let mut v: core::ffi::c_int = 0;
398 check(unsafe { cu(&mut v) })?;
399 Ok(v)
400}
401
402// =======================================================================
403// Double-precision exec + PlanMany (batched) + XT multi-GPU
404// =======================================================================
405
/// Generates the double-precision exec methods (`exec_d2z`, `exec_z2d`,
/// `exec_z2z`) for a plan type that has a `handle: cufftHandle` field.
macro_rules! exec_z_impls {
    ($plan:ty) => {
        impl $plan {
            /// Run a double-precision real-to-complex (D → Z) transform.
            /// The plan must have been built with `Transform::D2Z`.
            pub fn exec_d2z(
                &self,
                input: &mut DeviceBuffer<f64>,
                output: &mut DeviceBuffer<Complex64>,
            ) -> Result<()> {
                let lib = cufft()?;
                let exec = lib.cufft_exec_d2z()?;
                let src = input.as_raw().0 as *mut f64;
                let dst = output.as_raw().0 as *mut cufftDoubleComplex;
                // SAFETY: both pointers come from `DeviceBuffer`s the caller
                // has mutably borrowed. NOTE(review): lengths are not validated.
                let status = unsafe { exec(self.handle, src, dst) };
                check(status)
            }

            /// Run a double-precision complex-to-real (Z → D) transform.
            pub fn exec_z2d(
                &self,
                input: &mut DeviceBuffer<Complex64>,
                output: &mut DeviceBuffer<f64>,
            ) -> Result<()> {
                let lib = cufft()?;
                let exec = lib.cufft_exec_z2d()?;
                let src = input.as_raw().0 as *mut cufftDoubleComplex;
                let dst = output.as_raw().0 as *mut f64;
                // SAFETY: both pointers come from `DeviceBuffer`s the caller
                // has mutably borrowed. NOTE(review): lengths are not validated.
                let status = unsafe { exec(self.handle, src, dst) };
                check(status)
            }

            /// Run a double-precision complex-to-complex (Z → Z) transform
            /// in the `direction` chosen at exec time.
            pub fn exec_z2z(
                &self,
                input: &mut DeviceBuffer<Complex64>,
                output: &mut DeviceBuffer<Complex64>,
                direction: Direction,
            ) -> Result<()> {
                let lib = cufft()?;
                let exec = lib.cufft_exec_z2z()?;
                let src = input.as_raw().0 as *mut cufftDoubleComplex;
                let dst = output.as_raw().0 as *mut cufftDoubleComplex;
                // SAFETY: both pointers come from `DeviceBuffer`s the caller
                // has mutably borrowed. NOTE(review): lengths are not validated.
                let status = unsafe { exec(self.handle, src, dst, direction.raw()) };
                check(status)
            }
        }
    };
}
465
466exec_z_impls!(Plan1d);
467exec_z_impls!(Plan2d);
468
469/// A batched / many-rank plan (`cufftPlanMany`). Handles arbitrary
470/// rank + advanced-data-layout transforms.
471#[derive(Debug)]
472pub struct PlanMany {
473 handle: cufftHandle,
474}
475
476impl PlanMany {
477 /// Construct a batched plan. `n[rank]` is the transform shape;
478 /// `inembed` / `onembed` are the actual memory layouts of in/out
479 /// (pass `None` for packed). `istride`/`ostride` are element strides
480 /// between successive elements; `idist`/`odist` are element strides
481 /// between successive batches.
482 ///
483 /// # Example
484 ///
485 /// 32 packed 1-D R2C transforms of length 256 (e.g., a STFT frame).
486 ///
487 /// ```no_run
488 /// use baracuda_driver::{Context, Device, DeviceBuffer};
489 /// use baracuda_cufft::{PlanMany, Transform};
490 /// use baracuda_types::Complex32;
491 ///
492 /// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
493 /// let ctx = Context::new(&Device::get(0)?)?;
494 /// let n_per = 256i32;
495 /// let batch = 32i32;
496 /// let mut n = [n_per];
497 ///
498 /// // Packed contiguous layout: pass None for embeds, strides = 1, dist = transform length.
499 /// let plan = PlanMany::new(
500 /// /* rank */ 1,
501 /// /* n */ &mut n,
502 /// /* inemb */ None,
503 /// /* istr */ 1,
504 /// /* idist */ n_per,
505 /// /* onemb */ None,
506 /// /* ostr */ 1,
507 /// /* odist */ n_per / 2 + 1,
508 /// /* type */ Transform::R2C,
509 /// /* batch */ batch,
510 /// )?;
511 ///
512 /// let mut input: DeviceBuffer<f32> =
513 /// DeviceBuffer::zeros(&ctx, (n_per * batch) as usize)?;
514 /// let mut output: DeviceBuffer<Complex32> =
515 /// DeviceBuffer::new(&ctx, ((n_per / 2 + 1) * batch) as usize)?;
516 /// plan.exec_r2c(&mut input, &mut output)?;
517 /// # Ok(()) }
518 /// ```
519 #[allow(clippy::too_many_arguments)]
520 pub fn new(
521 rank: i32,
522 n: &mut [i32],
523 inembed: Option<&mut [i32]>,
524 istride: i32,
525 idist: i32,
526 onembed: Option<&mut [i32]>,
527 ostride: i32,
528 odist: i32,
529 ty: Transform,
530 batch: i32,
531 ) -> Result<Self> {
532 let c = cufft()?;
533 let cu = c.cufft_plan_many()?;
534 let mut h: cufftHandle = 0;
535 check(unsafe {
536 cu(
537 &mut h,
538 rank,
539 n.as_mut_ptr(),
540 inembed.map_or(core::ptr::null_mut(), |s| s.as_mut_ptr()),
541 istride,
542 idist,
543 onembed.map_or(core::ptr::null_mut(), |s| s.as_mut_ptr()),
544 ostride,
545 odist,
546 ty.raw(),
547 batch,
548 )
549 })?;
550 Ok(Self { handle: h })
551 }
552
553 /// Raw `cufftHandle`. Use with care.
554 #[inline]
555 pub fn as_raw(&self) -> cufftHandle {
556 self.handle
557 }
558
559 /// Bind the plan to a CUDA stream.
560 pub fn set_stream(&self, stream: &Stream) -> Result<()> {
561 let c = cufft()?;
562 let cu = c.cufft_set_stream()?;
563 check(unsafe { cu(self.handle, stream.as_raw() as _) })
564 }
565}
566
567impl Drop for PlanMany {
568 fn drop(&mut self) {
569 if let Ok(c) = cufft() {
570 if let Ok(cu) = c.cufft_destroy() {
571 let _ = unsafe { cu(self.handle) };
572 }
573 }
574 }
575}
576
577exec_z_impls!(PlanMany);
578
579impl PlanMany {
580 /// Execute R → C (single-precision R2C).
581 pub fn exec_r2c(
582 &self,
583 input: &mut DeviceBuffer<f32>,
584 output: &mut DeviceBuffer<Complex32>,
585 ) -> Result<()> {
586 let c = cufft()?;
587 let cu = c.cufft_exec_r2c()?;
588 check(unsafe {
589 cu(
590 self.handle,
591 input.as_raw().0 as *mut f32,
592 output.as_raw().0 as *mut cufftComplex,
593 )
594 })
595 }
596
597 /// Execute C → R (single-precision C2R).
598 pub fn exec_c2r(
599 &self,
600 input: &mut DeviceBuffer<Complex32>,
601 output: &mut DeviceBuffer<f32>,
602 ) -> Result<()> {
603 let c = cufft()?;
604 let cu = c.cufft_exec_c2r()?;
605 check(unsafe {
606 cu(
607 self.handle,
608 input.as_raw().0 as *mut cufftComplex,
609 output.as_raw().0 as *mut f32,
610 )
611 })
612 }
613
614 /// Execute C → C.
615 pub fn exec_c2c(
616 &self,
617 input: &mut DeviceBuffer<Complex32>,
618 output: &mut DeviceBuffer<Complex32>,
619 direction: Direction,
620 ) -> Result<()> {
621 let c = cufft()?;
622 let cu = c.cufft_exec_c2c()?;
623 check(unsafe {
624 cu(
625 self.handle,
626 input.as_raw().0 as *mut cufftComplex,
627 output.as_raw().0 as *mut cufftComplex,
628 direction.raw(),
629 )
630 })
631 }
632}
633
634/// Sizing estimates (workspace bytes) for a plan shape.
635pub fn estimate_1d(nx: i32, ty: Transform, batch: i32) -> Result<usize> {
636 let c = cufft()?;
637 let cu = c.cufft_estimate_1d()?;
638 let mut s: usize = 0;
639 check(unsafe { cu(nx, ty.raw(), batch, &mut s) })?;
640 Ok(s)
641}
642
643/// 2-D workspace-size estimate (bytes) for a plan of the given shape.
644/// Wraps `cufftEstimate2d`.
645pub fn estimate_2d(nx: i32, ny: i32, ty: Transform) -> Result<usize> {
646 let c = cufft()?;
647 let cu = c.cufft_estimate_2d()?;
648 let mut s: usize = 0;
649 check(unsafe { cu(nx, ny, ty.raw(), &mut s) })?;
650 Ok(s)
651}
652
653/// 3-D workspace-size estimate (bytes) for a plan of the given shape.
654/// Wraps `cufftEstimate3d`.
655pub fn estimate_3d(nx: i32, ny: i32, nz: i32, ty: Transform) -> Result<usize> {
656 let c = cufft()?;
657 let cu = c.cufft_estimate_3d()?;
658 let mut s: usize = 0;
659 check(unsafe { cu(nx, ny, nz, ty.raw(), &mut s) })?;
660 Ok(s)
661}
662
663/// Multi-GPU (XT) extension helpers. Use these to distribute a cuFFT
664/// plan across multiple GPUs via `cufftXtSetGPUs` + `cufftXtExec`.
665pub mod xt {
666 use super::*;
667
668 /// Spread a plan across `which_gpus` (CUDA device ordinals).
669 ///
670 /// # Safety
671 ///
672 /// `plan` must be a fresh (unexecuted) handle; all ordinals in
673 /// `which_gpus` must be live CUDA devices.
674 pub unsafe fn set_gpus(plan: cufftHandle, which_gpus: &mut [i32]) -> Result<()> { unsafe {
675 let c = cufft()?;
676 let cu = c.cufft_xt_set_gpus()?;
677 check(cu(plan, which_gpus.len() as i32, which_gpus.as_mut_ptr()))
678 }}
679
680 /// Allocate a multi-GPU `cudaLibXtDesc*` matching the plan.
681 /// Returns an opaque pointer that must be freed with [`free`].
682 ///
683 /// # Safety
684 ///
685 /// `plan` must have been configured with [`set_gpus`] first.
686 pub unsafe fn malloc(
687 plan: cufftHandle,
688 subformat: i32,
689 ) -> Result<*mut core::ffi::c_void> { unsafe {
690 let c = cufft()?;
691 let cu = c.cufft_xt_malloc()?;
692 let mut desc: *mut core::ffi::c_void = core::ptr::null_mut();
693 check(cu(plan, &mut desc, subformat))?;
694 Ok(desc)
695 }}
696
697 /// Free an XT descriptor from [`malloc`].
698 ///
699 /// # Safety
700 ///
701 /// `desc` must come from [`malloc`].
702 pub unsafe fn free(desc: *mut core::ffi::c_void) -> Result<()> { unsafe {
703 let c = cufft()?;
704 let cu = c.cufft_xt_free()?;
705 check(cu(desc))
706 }}
707
708 /// Multi-GPU memcpy between host / device / XT descriptors.
709 ///
710 /// # Safety
711 ///
712 /// Pointer kinds and `ty` must agree.
713 pub unsafe fn memcpy(
714 plan: cufftHandle,
715 dst: *mut core::ffi::c_void,
716 src: *mut core::ffi::c_void,
717 ty: i32,
718 ) -> Result<()> { unsafe {
719 let c = cufft()?;
720 let cu = c.cufft_xt_memcpy()?;
721 check(cu(plan, dst, src, ty))
722 }}
723
724 /// Execute the plan on its XT descriptors.
725 ///
726 /// # Safety
727 ///
728 /// `input` / `output` must be `cudaLibXtDesc*` pointers matching the plan.
729 pub unsafe fn exec_descriptor(
730 plan: cufftHandle,
731 input: *mut core::ffi::c_void,
732 output: *mut core::ffi::c_void,
733 direction: Direction,
734 ) -> Result<()> { unsafe {
735 let c = cufft()?;
736 let cu = c.cufft_xt_exec_descriptor()?;
737 check(cu(plan, input, output, direction.raw()))
738 }}
739}
740
741/// Set a user-allocated scratch work area (`cufftSetWorkArea`).
742///
743/// # Safety
744///
745/// `plan` must have `SetAutoAllocation(false)` first; `work_area` must
746/// be a live device pointer.
747pub unsafe fn set_work_area(plan: cufftHandle, work_area: *mut core::ffi::c_void) -> Result<()> { unsafe {
748 let c = cufft()?;
749 let cu = c.cufft_set_work_area()?;
750 check(cu(plan, work_area))
751}}
752
753/// Disable / re-enable automatic work-area allocation.
754pub fn set_auto_allocation(plan: cufftHandle, auto: bool) -> Result<()> {
755 let c = cufft()?;
756 let cu = c.cufft_set_auto_allocation()?;
757 check(unsafe { cu(plan, if auto { 1 } else { 0 }) })
758}
759
760/// Scratch bytes this plan currently needs.
761pub fn get_size(plan: cufftHandle) -> Result<usize> {
762 let c = cufft()?;
763 let cu = c.cufft_get_size()?;
764 let mut s: usize = 0;
765 check(unsafe { cu(plan, &mut s) })?;
766 Ok(s)
767}
768
769// ============================================================================
770// Two-step plan creation: cufftCreate + cufftMakePlan*
771// ============================================================================
772
773/// Generic cuFFT plan that supports the modern two-step creation flow:
774/// [`Plan::create`] gets you a fresh handle, then one of the
775/// `make_plan_*` methods configures the rank and reports the workspace
776/// size — useful when you want to allocate the work area yourself
777/// (call [`set_auto_allocation(false)`] before `make_plan_*`).
778///
779/// Unlike [`Plan1d`] / [`Plan2d`] / [`Plan3d`], this type can hold any
780/// rank, including the batched-many / 64-bit-many forms.
781pub struct Plan {
782 handle: cufftHandle,
783}
784
785unsafe impl Send for Plan {}
786
787impl core::fmt::Debug for Plan {
788 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
789 f.debug_struct("Plan")
790 .field("handle", &self.handle)
791 .finish()
792 }
793}
794
795impl Plan {
796 /// Allocate a fresh empty plan. The plan is unusable until you
797 /// finalize it with one of the `make_plan_*` methods.
798 pub fn create() -> Result<Self> {
799 let c = cufft()?;
800 let cu = c.cufft_create()?;
801 let mut plan: cufftHandle = 0;
802 check(unsafe { cu(&mut plan) })?;
803 Ok(Self { handle: plan })
804 }
805
806 /// Finalize as a 1-D plan of length `nx` and `batch` parallel
807 /// transforms. Returns the workspace size in bytes.
808 pub fn make_plan_1d(&self, nx: i32, transform: Transform, batch: i32) -> Result<usize> {
809 let c = cufft()?;
810 let cu = c.cufft_make_plan_1d()?;
811 let mut size: usize = 0;
812 check(unsafe { cu(self.handle, nx, transform.raw(), batch, &mut size) })?;
813 Ok(size)
814 }
815
816 /// Finalize as a 2-D plan. Returns workspace size in bytes.
817 pub fn make_plan_2d(&self, nx: i32, ny: i32, transform: Transform) -> Result<usize> {
818 let c = cufft()?;
819 let cu = c.cufft_make_plan_2d()?;
820 let mut size: usize = 0;
821 check(unsafe { cu(self.handle, nx, ny, transform.raw(), &mut size) })?;
822 Ok(size)
823 }
824
825 /// Finalize as a 3-D plan. Returns workspace size in bytes.
826 pub fn make_plan_3d(
827 &self,
828 nx: i32,
829 ny: i32,
830 nz: i32,
831 transform: Transform,
832 ) -> Result<usize> {
833 let c = cufft()?;
834 let cu = c.cufft_make_plan_3d()?;
835 let mut size: usize = 0;
836 check(unsafe { cu(self.handle, nx, ny, nz, transform.raw(), &mut size) })?;
837 Ok(size)
838 }
839
840 /// Finalize as a generic strided/batched plan. Returns workspace
841 /// size in bytes.
842 ///
843 /// # Safety
844 ///
845 /// `n`, `inembed`, `onembed` must be writable arrays of length
846 /// `rank` (cuFFT mutates them in some versions). Pass null for
847 /// `inembed` / `onembed` to use defaults.
848 #[allow(clippy::too_many_arguments)]
849 pub unsafe fn make_plan_many(
850 &self,
851 rank: i32,
852 n: &mut [i32],
853 inembed: *mut i32,
854 istride: i32,
855 idist: i32,
856 onembed: *mut i32,
857 ostride: i32,
858 odist: i32,
859 transform: Transform,
860 batch: i32,
861 ) -> Result<usize> { unsafe {
862 assert_eq!(n.len() as i32, rank, "n.len() must equal rank");
863 let c = cufft()?;
864 let cu = c.cufft_make_plan_many()?;
865 let mut size: usize = 0;
866 check(cu(
867 self.handle,
868 rank,
869 n.as_mut_ptr(),
870 inembed,
871 istride,
872 idist,
873 onembed,
874 ostride,
875 odist,
876 transform.raw(),
877 batch,
878 &mut size,
879 ))?;
880 Ok(size)
881 }}
882
883 /// 64-bit variant of [`make_plan_many`] — use this when any
884 /// dimension or stride exceeds `i32::MAX`.
885 ///
886 /// # Safety
887 ///
888 /// Same as [`make_plan_many`].
889 #[allow(clippy::too_many_arguments)]
890 pub unsafe fn make_plan_many64(
891 &self,
892 rank: i32,
893 n: &mut [i64],
894 inembed: *mut i64,
895 istride: i64,
896 idist: i64,
897 onembed: *mut i64,
898 ostride: i64,
899 odist: i64,
900 transform: Transform,
901 batch: i64,
902 ) -> Result<usize> { unsafe {
903 assert_eq!(n.len() as i32, rank, "n.len() must equal rank");
904 let c = cufft()?;
905 let cu = c.cufft_make_plan_many64()?;
906 let mut size: usize = 0;
907 check(cu(
908 self.handle,
909 rank,
910 n.as_mut_ptr(),
911 inembed,
912 istride,
913 idist,
914 onembed,
915 ostride,
916 odist,
917 transform.raw(),
918 batch,
919 &mut size,
920 ))?;
921 Ok(size)
922 }}
923
924 /// Bind subsequent exec calls to `stream`.
925 pub fn set_stream(&self, stream: &Stream) -> Result<()> {
926 let c = cufft()?;
927 let cu = c.cufft_set_stream()?;
928 check(unsafe { cu(self.handle, stream.as_raw() as _) })
929 }
930
931 /// Raw `cufftHandle`. Use with care.
932 #[inline]
933 pub fn as_raw(&self) -> cufftHandle {
934 self.handle
935 }
936}
937
938impl Drop for Plan {
939 fn drop(&mut self) {
940 if let Ok(c) = cufft() {
941 if let Ok(cu) = c.cufft_destroy() {
942 let _ = unsafe { cu(self.handle) };
943 }
944 }
945 }
946}
947
948// ============================================================================
949// Callback API (cufftXtSetCallback / Clear / SetCallbackSharedSize)
950// ============================================================================
951
952pub mod callback {
953 //! Pre/post callbacks attached to a cuFFT plan via the `cufftXt*`
954 //! callback entry points. The callback ABI is fixed by NVIDIA: each
955 //! callback receives the input/output element index, a caller-info
956 //! pointer, and the data; see the cuFFT reference for the exact
957 //! signatures by callback type. We expose only the raw setters
958 //! because the function-pointer types are PTX-shaped, not regular
959 //! `extern "C"` functions — they ship as device-side `__device__`
960 //! functions linked into the user's CUBIN.
961 use super::*;
962
963 /// Callback type values from the cuFFT header
964 /// (`cufftXtCallbackType`).
965 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
966 #[repr(i32)]
967 pub enum CallbackType {
968 /// Load callback for complex single precision.
969 LoadComplex = 0,
970 /// Load callback for complex double precision.
971 LoadDoubleComplex = 1,
972 /// Load callback for real single precision.
973 LoadReal = 2,
974 /// Load callback for real double precision.
975 LoadDoubleReal = 3,
976 /// Store callback for complex single precision.
977 StoreComplex = 4,
978 /// Store callback for complex double precision.
979 StoreDoubleComplex = 5,
980 /// Store callback for real single precision.
981 StoreReal = 6,
982 /// Store callback for real double precision.
983 StoreDoubleReal = 7,
984 }
985
986 /// Attach a load/store callback to the plan. `callback_routine` is
987 /// an array of device function pointers — one per GPU for
988 /// multi-GPU plans, otherwise a single-element array.
989 /// `caller_info` is parallel; pass null for "no caller info".
990 ///
991 /// # Safety
992 ///
993 /// `callback_routine[i]` must be a `__device__` function with the
994 /// signature cuFFT expects for `cb_type` (see the cuFFT reference).
995 /// The routine and `caller_info` must outlive every plan exec call.
996 pub unsafe fn set(
997 plan: cufftHandle,
998 callback_routine: &mut [*mut core::ffi::c_void],
999 cb_type: CallbackType,
1000 caller_info: &mut [*mut core::ffi::c_void],
1001 ) -> Result<()> { unsafe {
1002 assert_eq!(
1003 callback_routine.len(),
1004 caller_info.len(),
1005 "callback_routine and caller_info must have the same length"
1006 );
1007 let c = cufft()?;
1008 let cu = c.cufft_xt_set_callback()?;
1009 check(cu(
1010 plan,
1011 callback_routine.as_mut_ptr(),
1012 cb_type as i32,
1013 caller_info.as_mut_ptr(),
1014 ))
1015 }}
1016
1017 /// Detach any previously set callback of `cb_type`.
1018 pub fn clear(plan: cufftHandle, cb_type: CallbackType) -> Result<()> {
1019 let c = cufft()?;
1020 let cu = c.cufft_xt_clear_callback()?;
1021 check(unsafe { cu(plan, cb_type as i32) })
1022 }
1023
1024 /// Reserve `shared_size` bytes of dynamic shared memory per kernel
1025 /// for the callback. Maximum permissible value is GPU-dependent.
1026 pub fn set_shared_size(
1027 plan: cufftHandle,
1028 cb_type: CallbackType,
1029 shared_size: usize,
1030 ) -> Result<()> {
1031 let c = cufft()?;
1032 let cu = c.cufft_xt_set_callback_shared_size()?;
1033 check(unsafe { cu(plan, cb_type as i32, shared_size) })
1034 }
1035}