1use core::cell::Cell;
20use core::ffi::c_void;
21use core::marker::PhantomData;
22
23use baracuda_cutlass::{Error, Result};
24use baracuda_driver::Stream;
25use baracuda_kernels_sys::{
26 baracuda_kernels_scale_inplace_real_f32_run, baracuda_kernels_scale_inplace_real_f64_run,
27 cufftComplex, cufftDestroy, cufftDoubleComplex, cufftExecC2R, cufftExecD2Z, cufftExecR2C,
28 cufftExecZ2D, cufftHandle, cufftPlan1d, cufftSetStream, CUFFT_C2R, CUFFT_D2Z, CUFFT_R2C,
29 CUFFT_Z2D,
30};
31use baracuda_kernels_types::{
32 ArchSku, BackendKind, Complex32, Complex64, Element, ElementKind, FftKind, KernelSku,
33 MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
34};
35
36use super::fft::{cufft_to_status, map_status};
37
/// Sentinel stored in a plan's `handle` cell until `ensure_handle` has created
/// the cuFFT plan. `Drop` only calls `cufftDestroy` on handles that differ from
/// this value.
const HANDLE_UNINIT: cufftHandle = -1;
39
/// Problem description for a batched 1-D real-to-complex FFT.
///
/// `RfftPlan::select` validates it and stores a copy inside the plan.
#[derive(Copy, Clone, Debug)]
pub struct RfftDescriptor {
    /// Transform length (real samples per row). `select` rejects values <= 0.
    pub n: i32,
    /// Number of independent transforms (rows). `select` rejects values <= 0.
    pub batch: i32,
    /// Real element type. It must equal the plan's `T::KIND` and be `F32` or `F64`.
    pub element: ElementKind,
}
56
/// Input/output tensors for one `RfftPlan::run` call.
///
/// `T` is the real element type and `C` is its complex partner.
pub struct RfftArgs<'a, T: Element, C: Element> {
    /// Real input. `run` requires shape `[batch, n]`.
    pub x: TensorRef<'a, T, 2>,
    /// Complex output. `run` requires shape `[batch, n / 2 + 1]`.
    pub y: TensorMut<'a, C, 2>,
}
67
/// Batched 1-D real-to-complex FFT executed through cuFFT.
///
/// The transform is `CUFFT_R2C` for `f32` and `CUFFT_D2Z` for `f64`. The cuFFT
/// plan is created lazily on the first `run` and destroyed in `Drop`. Because
/// the handle lives in a `Cell`, this type is not `Sync`.
pub struct RfftPlan<T: Element> {
    /// Validated problem description, copied in by `select`.
    desc: RfftDescriptor,
    /// Kernel identity and precision metadata returned by `sku()`.
    sku: KernelSku,
    /// cuFFT plan handle, or `HANDLE_UNINIT` until `ensure_handle` creates it.
    handle: Cell<cufftHandle>,
    _marker: PhantomData<T>,
}
95
96impl<T: Element> RfftPlan<T> {
97 pub fn select(_stream: &Stream, desc: &RfftDescriptor, _pref: PlanPreference) -> Result<Self> {
99 if desc.element != T::KIND {
100 return Err(Error::Unsupported(
101 "baracuda-kernels::RfftPlan: descriptor.element != T::KIND",
102 ));
103 }
104 if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
105 return Err(Error::Unsupported(
106 "baracuda-kernels::RfftPlan: R2C FFT supports f32 + f64 only",
107 ));
108 }
109 if desc.n <= 0 {
110 return Err(Error::InvalidProblem(
111 "baracuda-kernels::RfftPlan: n must be > 0",
112 ));
113 }
114 if desc.batch <= 0 {
115 return Err(Error::InvalidProblem(
116 "baracuda-kernels::RfftPlan: batch must be > 0",
117 ));
118 }
119 let math_precision = match T::KIND {
120 ElementKind::F64 => MathPrecision::F64,
121 _ => MathPrecision::F32,
122 };
123 let aux = match T::KIND {
124 ElementKind::F32 => Some(ElementKind::Complex32),
125 ElementKind::F64 => Some(ElementKind::Complex64),
126 _ => None,
127 };
128 let precision_guarantee = PrecisionGuarantee {
129 math_precision,
130 accumulator: T::KIND,
131 bit_stable_on_same_hardware: false,
132 deterministic: true,
133 };
134 let sku = KernelSku {
135 category: OpCategory::Fft,
136 op: FftKind::Rfft as u16,
137 element: T::KIND,
138 aux_element: aux,
139 layout: None,
140 epilogue: None,
141 arch: ArchSku::Sm80,
142 backend: BackendKind::Cufft,
143 precision_guarantee,
144 };
145 Ok(Self {
146 desc: *desc,
147 sku,
148 handle: Cell::new(HANDLE_UNINIT),
149 _marker: PhantomData,
150 })
151 }
152
153 #[inline]
155 pub fn sku(&self) -> KernelSku {
156 self.sku
157 }
158
159 #[inline]
161 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
162 self.sku.precision_guarantee
163 }
164
165 #[inline]
168 pub fn workspace_size(&self) -> usize {
169 0
170 }
171
172 fn ensure_handle(&self) -> Result<cufftHandle> {
173 let h = self.handle.get();
174 if h != HANDLE_UNINIT {
175 return Ok(h);
176 }
177 let fft_type = match T::KIND {
178 ElementKind::F32 => CUFFT_R2C,
179 ElementKind::F64 => CUFFT_D2Z,
180 _ => unreachable!("select() gates on F32 / F64"),
181 };
182 let mut handle: cufftHandle = HANDLE_UNINIT;
183 let status = unsafe {
184 cufftPlan1d(
185 &mut handle as *mut _,
186 self.desc.n,
187 fft_type,
188 self.desc.batch,
189 )
190 };
191 if status != 0 {
192 return Err(Error::CutlassInternal(cufft_to_status(status)));
193 }
194 self.handle.set(handle);
195 Ok(handle)
196 }
197
198 fn bind_stream(&self, handle: cufftHandle, stream: &Stream) -> Result<()> {
199 let stream_ptr = stream.as_raw() as *mut c_void;
200 let status = unsafe { cufftSetStream(handle, stream_ptr) };
201 if status != 0 {
202 return Err(Error::CutlassInternal(cufft_to_status(status)));
203 }
204 Ok(())
205 }
206}
207
208impl RfftPlan<f32> {
209 pub fn run(
211 &self,
212 stream: &Stream,
213 _workspace: Workspace<'_>,
214 args: RfftArgs<'_, f32, Complex32>,
215 ) -> Result<()> {
216 let n = self.desc.n;
217 let batch = self.desc.batch;
218 let in_shape = [batch, n];
219 let out_shape = [batch, n / 2 + 1];
220 if args.x.shape != in_shape {
221 return Err(Error::InvalidProblem(
222 "baracuda-kernels::RfftPlan<f32>: x shape != [batch, n]",
223 ));
224 }
225 if args.y.shape != out_shape {
226 return Err(Error::InvalidProblem(
227 "baracuda-kernels::RfftPlan<f32>: y shape != [batch, n/2 + 1]",
228 ));
229 }
230 let in_numel = (batch as i64) * (n as i64);
231 let out_numel = (batch as i64) * ((n / 2 + 1) as i64);
232 if (args.x.data.len() as i64) < in_numel {
233 return Err(Error::BufferTooSmall {
234 needed: in_numel as usize,
235 got: args.x.data.len(),
236 });
237 }
238 if (args.y.data.len() as i64) < out_numel {
239 return Err(Error::BufferTooSmall {
240 needed: out_numel as usize,
241 got: args.y.data.len(),
242 });
243 }
244 if in_numel == 0 {
245 return Ok(());
246 }
247
248 let handle = self.ensure_handle()?;
249 self.bind_stream(handle, stream)?;
250
251 let idata = args.x.data.as_raw().0 as *mut f32;
252 let odata = args.y.data.as_raw().0 as *mut cufftComplex;
253 let status = unsafe { cufftExecR2C(handle, idata, odata) };
254 if status != 0 {
255 return Err(Error::CutlassInternal(cufft_to_status(status)));
256 }
257 Ok(())
258 }
259}
260
261impl RfftPlan<f64> {
262 pub fn run(
264 &self,
265 stream: &Stream,
266 _workspace: Workspace<'_>,
267 args: RfftArgs<'_, f64, Complex64>,
268 ) -> Result<()> {
269 let n = self.desc.n;
270 let batch = self.desc.batch;
271 let in_shape = [batch, n];
272 let out_shape = [batch, n / 2 + 1];
273 if args.x.shape != in_shape {
274 return Err(Error::InvalidProblem(
275 "baracuda-kernels::RfftPlan<f64>: x shape != [batch, n]",
276 ));
277 }
278 if args.y.shape != out_shape {
279 return Err(Error::InvalidProblem(
280 "baracuda-kernels::RfftPlan<f64>: y shape != [batch, n/2 + 1]",
281 ));
282 }
283 let in_numel = (batch as i64) * (n as i64);
284 let out_numel = (batch as i64) * ((n / 2 + 1) as i64);
285 if (args.x.data.len() as i64) < in_numel {
286 return Err(Error::BufferTooSmall {
287 needed: in_numel as usize,
288 got: args.x.data.len(),
289 });
290 }
291 if (args.y.data.len() as i64) < out_numel {
292 return Err(Error::BufferTooSmall {
293 needed: out_numel as usize,
294 got: args.y.data.len(),
295 });
296 }
297 if in_numel == 0 {
298 return Ok(());
299 }
300
301 let handle = self.ensure_handle()?;
302 self.bind_stream(handle, stream)?;
303
304 let idata = args.x.data.as_raw().0 as *mut f64;
305 let odata = args.y.data.as_raw().0 as *mut cufftDoubleComplex;
306 let status = unsafe { cufftExecD2Z(handle, idata, odata) };
307 if status != 0 {
308 return Err(Error::CutlassInternal(cufft_to_status(status)));
309 }
310 Ok(())
311 }
312}
313
314impl<T: Element> Drop for RfftPlan<T> {
315 fn drop(&mut self) {
316 let h = self.handle.get();
317 if h != HANDLE_UNINIT {
318 unsafe {
319 let _ = cufftDestroy(h);
320 }
321 self.handle.set(HANDLE_UNINIT);
322 }
323 }
324}
325
/// Problem description for a batched 1-D complex-to-real (inverse) FFT.
///
/// `IrfftPlan::select` validates it and stores a copy inside the plan.
#[derive(Copy, Clone, Debug)]
pub struct IrfftDescriptor {
    /// Length of each real output row. Each input row holds `n / 2 + 1` complex
    /// bins. `select` rejects values <= 0.
    pub n: i32,
    /// Number of independent transforms (rows). `select` rejects values <= 0.
    pub batch: i32,
    /// Real element type. It must equal the plan's `T::KIND` and be `F32` or `F64`.
    pub element: ElementKind,
}
346
/// Input/output tensors for one `IrfftPlan::run` call.
///
/// `T` is the real element type and `C` is its complex partner.
pub struct IrfftArgs<'a, T: Element, C: Element> {
    /// Complex input spectrum. `run` requires shape `[batch, n / 2 + 1]`.
    ///
    /// NOTE(review): the cuFFT documentation states that out-of-place
    /// complex-to-real transforms overwrite their input buffer, so these
    /// contents may be clobbered even though this is a `TensorRef`. Confirm
    /// this, and consider copying through a workspace.
    pub x: TensorRef<'a, C, 2>,
    /// Real output. `run` requires shape `[batch, n]`.
    pub y: TensorMut<'a, T, 2>,
}
357
/// Batched 1-D complex-to-real (inverse) FFT executed through cuFFT.
///
/// The transform is `CUFFT_C2R` for `f32` and `CUFFT_Z2D` for `f64`. `run` then
/// scales the real output by `1 / n`. The cuFFT plan is created lazily on the
/// first `run` and destroyed in `Drop`. Because the handle lives in a `Cell`,
/// this type is not `Sync`.
pub struct IrfftPlan<T: Element> {
    /// Validated problem description, copied in by `select`.
    desc: IrfftDescriptor,
    /// Kernel identity and precision metadata returned by `sku()`.
    sku: KernelSku,
    /// cuFFT plan handle, or `HANDLE_UNINIT` until `ensure_handle` creates it.
    handle: Cell<cufftHandle>,
    _marker: PhantomData<T>,
}
387
388impl<T: Element> IrfftPlan<T> {
389 pub fn select(
391 _stream: &Stream,
392 desc: &IrfftDescriptor,
393 _pref: PlanPreference,
394 ) -> Result<Self> {
395 if desc.element != T::KIND {
396 return Err(Error::Unsupported(
397 "baracuda-kernels::IrfftPlan: descriptor.element != T::KIND",
398 ));
399 }
400 if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
401 return Err(Error::Unsupported(
402 "baracuda-kernels::IrfftPlan: C2R FFT supports f32 + f64 only",
403 ));
404 }
405 if desc.n <= 0 {
406 return Err(Error::InvalidProblem(
407 "baracuda-kernels::IrfftPlan: n must be > 0",
408 ));
409 }
410 if desc.batch <= 0 {
411 return Err(Error::InvalidProblem(
412 "baracuda-kernels::IrfftPlan: batch must be > 0",
413 ));
414 }
415 let math_precision = match T::KIND {
416 ElementKind::F64 => MathPrecision::F64,
417 _ => MathPrecision::F32,
418 };
419 let aux = match T::KIND {
420 ElementKind::F32 => Some(ElementKind::Complex32),
421 ElementKind::F64 => Some(ElementKind::Complex64),
422 _ => None,
423 };
424 let precision_guarantee = PrecisionGuarantee {
425 math_precision,
426 accumulator: T::KIND,
427 bit_stable_on_same_hardware: false,
428 deterministic: true,
429 };
430 let sku = KernelSku {
431 category: OpCategory::Fft,
432 op: FftKind::Irfft as u16,
433 element: T::KIND,
434 aux_element: aux,
435 layout: None,
436 epilogue: None,
437 arch: ArchSku::Sm80,
438 backend: BackendKind::Cufft,
439 precision_guarantee,
440 };
441 Ok(Self {
442 desc: *desc,
443 sku,
444 handle: Cell::new(HANDLE_UNINIT),
445 _marker: PhantomData,
446 })
447 }
448
449 #[inline]
451 pub fn sku(&self) -> KernelSku {
452 self.sku
453 }
454
455 #[inline]
457 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
458 self.sku.precision_guarantee
459 }
460
461 #[inline]
463 pub fn workspace_size(&self) -> usize {
464 0
465 }
466
467 fn ensure_handle(&self) -> Result<cufftHandle> {
468 let h = self.handle.get();
469 if h != HANDLE_UNINIT {
470 return Ok(h);
471 }
472 let fft_type = match T::KIND {
473 ElementKind::F32 => CUFFT_C2R,
474 ElementKind::F64 => CUFFT_Z2D,
475 _ => unreachable!("select() gates on F32 / F64"),
476 };
477 let mut handle: cufftHandle = HANDLE_UNINIT;
478 let status = unsafe {
479 cufftPlan1d(
480 &mut handle as *mut _,
481 self.desc.n,
482 fft_type,
483 self.desc.batch,
484 )
485 };
486 if status != 0 {
487 return Err(Error::CutlassInternal(cufft_to_status(status)));
488 }
489 self.handle.set(handle);
490 Ok(handle)
491 }
492
493 fn bind_stream(&self, handle: cufftHandle, stream: &Stream) -> Result<()> {
494 let stream_ptr = stream.as_raw() as *mut c_void;
495 let status = unsafe { cufftSetStream(handle, stream_ptr) };
496 if status != 0 {
497 return Err(Error::CutlassInternal(cufft_to_status(status)));
498 }
499 Ok(())
500 }
501}
502
503impl IrfftPlan<f32> {
504 pub fn run(
507 &self,
508 stream: &Stream,
509 _workspace: Workspace<'_>,
510 args: IrfftArgs<'_, f32, Complex32>,
511 ) -> Result<()> {
512 let n = self.desc.n;
513 let batch = self.desc.batch;
514 let in_shape = [batch, n / 2 + 1];
515 let out_shape = [batch, n];
516 if args.x.shape != in_shape {
517 return Err(Error::InvalidProblem(
518 "baracuda-kernels::IrfftPlan<f32>: x shape != [batch, n/2 + 1]",
519 ));
520 }
521 if args.y.shape != out_shape {
522 return Err(Error::InvalidProblem(
523 "baracuda-kernels::IrfftPlan<f32>: y shape != [batch, n]",
524 ));
525 }
526 let in_numel = (batch as i64) * ((n / 2 + 1) as i64);
527 let out_numel = (batch as i64) * (n as i64);
528 if (args.x.data.len() as i64) < in_numel {
529 return Err(Error::BufferTooSmall {
530 needed: in_numel as usize,
531 got: args.x.data.len(),
532 });
533 }
534 if (args.y.data.len() as i64) < out_numel {
535 return Err(Error::BufferTooSmall {
536 needed: out_numel as usize,
537 got: args.y.data.len(),
538 });
539 }
540 if out_numel == 0 {
541 return Ok(());
542 }
543
544 let handle = self.ensure_handle()?;
545 self.bind_stream(handle, stream)?;
546
547 let idata = args.x.data.as_raw().0 as *mut cufftComplex;
548 let odata = args.y.data.as_raw().0 as *mut f32;
549 let status = unsafe { cufftExecC2R(handle, idata, odata) };
550 if status != 0 {
551 return Err(Error::CutlassInternal(cufft_to_status(status)));
552 }
553
554 let scale = 1.0_f32 / (n as f32);
556 let stream_ptr = stream.as_raw() as *mut c_void;
557 let s = unsafe {
558 baracuda_kernels_scale_inplace_real_f32_run(
559 out_numel,
560 scale,
561 odata as *mut c_void,
562 core::ptr::null_mut(),
563 0,
564 stream_ptr,
565 )
566 };
567 map_status(s)
568 }
569}
570
571impl IrfftPlan<f64> {
572 pub fn run(
574 &self,
575 stream: &Stream,
576 _workspace: Workspace<'_>,
577 args: IrfftArgs<'_, f64, Complex64>,
578 ) -> Result<()> {
579 let n = self.desc.n;
580 let batch = self.desc.batch;
581 let in_shape = [batch, n / 2 + 1];
582 let out_shape = [batch, n];
583 if args.x.shape != in_shape {
584 return Err(Error::InvalidProblem(
585 "baracuda-kernels::IrfftPlan<f64>: x shape != [batch, n/2 + 1]",
586 ));
587 }
588 if args.y.shape != out_shape {
589 return Err(Error::InvalidProblem(
590 "baracuda-kernels::IrfftPlan<f64>: y shape != [batch, n]",
591 ));
592 }
593 let in_numel = (batch as i64) * ((n / 2 + 1) as i64);
594 let out_numel = (batch as i64) * (n as i64);
595 if (args.x.data.len() as i64) < in_numel {
596 return Err(Error::BufferTooSmall {
597 needed: in_numel as usize,
598 got: args.x.data.len(),
599 });
600 }
601 if (args.y.data.len() as i64) < out_numel {
602 return Err(Error::BufferTooSmall {
603 needed: out_numel as usize,
604 got: args.y.data.len(),
605 });
606 }
607 if out_numel == 0 {
608 return Ok(());
609 }
610
611 let handle = self.ensure_handle()?;
612 self.bind_stream(handle, stream)?;
613
614 let idata = args.x.data.as_raw().0 as *mut cufftDoubleComplex;
615 let odata = args.y.data.as_raw().0 as *mut f64;
616 let status = unsafe { cufftExecZ2D(handle, idata, odata) };
617 if status != 0 {
618 return Err(Error::CutlassInternal(cufft_to_status(status)));
619 }
620
621 let scale = 1.0_f64 / (n as f64);
622 let stream_ptr = stream.as_raw() as *mut c_void;
623 let s = unsafe {
624 baracuda_kernels_scale_inplace_real_f64_run(
625 out_numel,
626 scale,
627 odata as *mut c_void,
628 core::ptr::null_mut(),
629 0,
630 stream_ptr,
631 )
632 };
633 map_status(s)
634 }
635}
636
637impl<T: Element> Drop for IrfftPlan<T> {
638 fn drop(&mut self) {
639 let h = self.handle.get();
640 if h != HANDLE_UNINIT {
641 unsafe {
642 let _ = cufftDestroy(h);
643 }
644 self.handle.set(HANDLE_UNINIT);
645 }
646 }
647}