1#![warn(missing_debug_implementations)]
47
48use baracuda_cudnn_sys::{
49 cudnn, cudnnActivationDescriptor_t, cudnnActivationMode_t, cudnnAttnDescriptor_t,
50 cudnnBackendAttributeName_t, cudnnBackendAttributeType_t, cudnnBackendDescriptorType_t,
51 cudnnBackendDescriptor_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t,
52 cudnnConvolutionBwdDataAlgo_t, cudnnConvolutionBwdFilterAlgo_t,
53 cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, cudnnConvolutionMode_t,
54 cudnnDataType_t, cudnnDropoutDescriptor_t, cudnnFilterDescriptor_t,
55 cudnnHandle_t, cudnnIndicesType_t, cudnnLRNDescriptor_t, cudnnMathType_t, cudnnNanPropagation_t,
56 cudnnNormAlgo_t, cudnnNormMode_t, cudnnNormOps_t, cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t,
57 cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnReduceTensorDescriptor_t,
58 cudnnReduceTensorIndices_t, cudnnReduceTensorOp_t, cudnnReorderType_t,
59 cudnnSeqDataDescriptor_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, cudnnStatus_t,
60 cudnnTensorDescriptor_t, cudnnTensorFormat_t,
61};
62use baracuda_driver::{DeviceBuffer, Stream};
63use baracuda_types::DeviceRepr;
64
65pub type Error = baracuda_core::Error<cudnnStatus_t>;
67pub type Result<T, E = Error> = core::result::Result<T, E>;
69
70#[inline]
71fn check(status: cudnnStatus_t) -> Result<()> {
72 Error::check(status)
73}
74
75pub struct Handle {
77 handle: cudnnHandle_t,
78}
79
80unsafe impl Send for Handle {}
81
82impl core::fmt::Debug for Handle {
83 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
84 f.debug_struct("cudnn::Handle")
85 .field("handle", &self.handle)
86 .finish()
87 }
88}
89
90impl Handle {
91 pub fn new() -> Result<Self> {
93 let c = cudnn()?;
94 let cu = c.cudnn_create()?;
95 let mut h: cudnnHandle_t = core::ptr::null_mut();
96 check(unsafe { cu(&mut h) })?;
97 Ok(Self { handle: h })
98 }
99
100 pub fn set_stream(&self, stream: &Stream) -> Result<()> {
102 let c = cudnn()?;
103 let cu = c.cudnn_set_stream()?;
104 check(unsafe { cu(self.handle, stream.as_raw() as _) })
105 }
106
107 #[inline]
109 pub fn as_raw(&self) -> cudnnHandle_t {
110 self.handle
111 }
112}
113
114impl Drop for Handle {
115 fn drop(&mut self) {
116 if let Ok(c) = cudnn() {
117 if let Ok(cu) = c.cudnn_destroy() {
118 let _ = unsafe { cu(self.handle) };
119 }
120 }
121 }
122}
123
124pub fn version() -> Result<usize> {
128 let c = cudnn()?;
129 let cu = c.cudnn_get_version()?;
130 Ok(unsafe { cu() })
132}
133
134#[derive(Copy, Clone, Debug, Eq, PartialEq)]
136pub enum DType {
137 F32,
139 F64,
141 F16,
143 BF16,
145 I8,
147 I32,
149}
150
151impl DType {
152 fn raw(self) -> cudnnDataType_t {
153 match self {
154 DType::F32 => cudnnDataType_t::Float,
155 DType::F64 => cudnnDataType_t::Double,
156 DType::F16 => cudnnDataType_t::Half,
157 DType::BF16 => cudnnDataType_t::BFloat16,
158 DType::I8 => cudnnDataType_t::Int8,
159 DType::I32 => cudnnDataType_t::Int32,
160 }
161 }
162}
163
164pub trait CudnnDataType: DeviceRepr + Copy + 'static {
186 const DTYPE: DType;
188}
189
190impl CudnnDataType for f32 {
191 const DTYPE: DType = DType::F32;
192}
193impl CudnnDataType for f64 {
194 const DTYPE: DType = DType::F64;
195}
196impl CudnnDataType for baracuda_types::Half {
197 const DTYPE: DType = DType::F16;
198}
199impl CudnnDataType for baracuda_types::BFloat16 {
200 const DTYPE: DType = DType::BF16;
201}
202impl CudnnDataType for i8 {
203 const DTYPE: DType = DType::I8;
204}
205impl CudnnDataType for i32 {
206 const DTYPE: DType = DType::I32;
207}
208
209#[cfg(feature = "half-crate")]
214impl CudnnDataType for half::f16 {
215 const DTYPE: DType = DType::F16;
216}
217#[cfg(feature = "half-crate")]
218impl CudnnDataType for half::bf16 {
219 const DTYPE: DType = DType::BF16;
220}
221
222#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
224pub enum TensorFormat {
225 #[default]
227 Nchw,
228 Nhwc,
230}
231
232impl TensorFormat {
233 fn raw(self) -> cudnnTensorFormat_t {
234 match self {
235 TensorFormat::Nchw => cudnnTensorFormat_t::Nchw,
236 TensorFormat::Nhwc => cudnnTensorFormat_t::Nhwc,
237 }
238 }
239}
240
241pub struct TensorDescriptor {
243 desc: cudnnTensorDescriptor_t,
244}
245
246unsafe impl Send for TensorDescriptor {}
247
248impl core::fmt::Debug for TensorDescriptor {
249 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
250 f.debug_struct("TensorDescriptor")
251 .field("desc", &self.desc)
252 .finish_non_exhaustive()
253 }
254}
255
256impl TensorDescriptor {
257 pub fn new_4d(
259 format: TensorFormat,
260 dtype: DType,
261 n: i32,
262 c: i32,
263 h: i32,
264 w: i32,
265 ) -> Result<Self> {
266 let cu_crate = cudnn()?;
267 let create = cu_crate.cudnn_create_tensor_descriptor()?;
268 let set = cu_crate.cudnn_set_tensor_4d_descriptor()?;
269 let mut desc: cudnnTensorDescriptor_t = core::ptr::null_mut();
270 check(unsafe { create(&mut desc) })?;
271 let this = Self { desc };
272 check(unsafe { set(this.desc, format.raw(), dtype.raw(), n, c, h, w) })?;
273 Ok(this)
274 }
275
276 pub fn new_nd(dtype: DType, dims: &[i32], strides: &[i32]) -> Result<Self> {
280 assert_eq!(
281 dims.len(),
282 strides.len(),
283 "dims/strides length mismatch for Nd tensor descriptor"
284 );
285 let cu_crate = cudnn()?;
286 let create = cu_crate.cudnn_create_tensor_descriptor()?;
287 let set = cu_crate.cudnn_set_tensor_nd_descriptor()?;
288 let mut desc: cudnnTensorDescriptor_t = core::ptr::null_mut();
289 check(unsafe { create(&mut desc) })?;
290 let this = Self { desc };
291 check(unsafe {
292 set(
293 this.desc,
294 dtype.raw(),
295 dims.len() as core::ffi::c_int,
296 dims.as_ptr(),
297 strides.as_ptr(),
298 )
299 })?;
300 Ok(this)
301 }
302
303 #[inline]
305 pub fn as_raw(&self) -> cudnnTensorDescriptor_t {
306 self.desc
307 }
308}
309
310impl Drop for TensorDescriptor {
311 fn drop(&mut self) {
312 if let Ok(c) = cudnn() {
313 if let Ok(cu) = c.cudnn_destroy_tensor_descriptor() {
314 let _ = unsafe { cu(self.desc) };
315 }
316 }
317 }
318}
319
320#[derive(Copy, Clone, Debug, Eq, PartialEq)]
322pub enum ActivationMode {
323 Relu,
325 Sigmoid,
327 Tanh,
329 ClippedRelu,
331 Elu,
333 Identity,
335 Swish,
337}
338
339impl ActivationMode {
340 fn raw(self) -> cudnnActivationMode_t {
341 match self {
342 ActivationMode::Relu => cudnnActivationMode_t::Relu,
343 ActivationMode::Sigmoid => cudnnActivationMode_t::Sigmoid,
344 ActivationMode::Tanh => cudnnActivationMode_t::Tanh,
345 ActivationMode::ClippedRelu => cudnnActivationMode_t::ClippedRelu,
346 ActivationMode::Elu => cudnnActivationMode_t::Elu,
347 ActivationMode::Identity => cudnnActivationMode_t::Identity,
348 ActivationMode::Swish => cudnnActivationMode_t::Swish,
349 }
350 }
351}
352
353pub struct ActivationDescriptor {
355 desc: cudnnActivationDescriptor_t,
356}
357
358unsafe impl Send for ActivationDescriptor {}
359
360impl core::fmt::Debug for ActivationDescriptor {
361 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
362 f.debug_struct("ActivationDescriptor")
363 .field("desc", &self.desc)
364 .finish_non_exhaustive()
365 }
366}
367
368impl ActivationDescriptor {
369 pub fn new(mode: ActivationMode, coef: f64) -> Result<Self> {
372 let c = cudnn()?;
373 let create = c.cudnn_create_activation_descriptor()?;
374 let set = c.cudnn_set_activation_descriptor()?;
375 let mut desc: cudnnActivationDescriptor_t = core::ptr::null_mut();
376 check(unsafe { create(&mut desc) })?;
377 let this = Self { desc };
378 check(unsafe {
379 set(
380 this.desc,
381 mode.raw(),
382 cudnnNanPropagation_t::PropagateNan,
383 coef,
384 )
385 })?;
386 Ok(this)
387 }
388
389 #[inline]
391 pub fn as_raw(&self) -> cudnnActivationDescriptor_t {
392 self.desc
393 }
394}
395
396impl Drop for ActivationDescriptor {
397 fn drop(&mut self) {
398 if let Ok(c) = cudnn() {
399 if let Ok(cu) = c.cudnn_destroy_activation_descriptor() {
400 let _ = unsafe { cu(self.desc) };
401 }
402 }
403 }
404}
405
406#[allow(clippy::too_many_arguments)]
436pub fn activation_forward<T: DeviceRepr>(
437 handle: &Handle,
438 activation: &ActivationDescriptor,
439 alpha: f32,
440 x_desc: &TensorDescriptor,
441 x: &DeviceBuffer<T>,
442 beta: f32,
443 y_desc: &TensorDescriptor,
444 y: &mut DeviceBuffer<T>,
445) -> Result<()> {
446 let c = cudnn()?;
447 let cu = c.cudnn_activation_forward()?;
448 check(unsafe {
449 cu(
450 handle.handle,
451 activation.desc,
452 &alpha as *const f32 as *const core::ffi::c_void,
453 x_desc.desc,
454 x.as_raw().0 as *const core::ffi::c_void,
455 &beta as *const f32 as *const core::ffi::c_void,
456 y_desc.desc,
457 y.as_raw().0 as *mut core::ffi::c_void,
458 )
459 })
460}
461
462pub struct FilterDescriptor {
466 desc: cudnnFilterDescriptor_t,
467}
468
469unsafe impl Send for FilterDescriptor {}
470
471impl core::fmt::Debug for FilterDescriptor {
472 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
473 f.debug_struct("FilterDescriptor")
474 .field("desc", &self.desc)
475 .finish_non_exhaustive()
476 }
477}
478
479impl FilterDescriptor {
480 pub fn new_4d(
483 format: TensorFormat,
484 dtype: DType,
485 k: i32,
486 c: i32,
487 h: i32,
488 w: i32,
489 ) -> Result<Self> {
490 let cu = cudnn()?;
491 let create = cu.cudnn_create_filter_descriptor()?;
492 let set = cu.cudnn_set_filter_4d_descriptor()?;
493 let mut desc: cudnnFilterDescriptor_t = core::ptr::null_mut();
494 check(unsafe { create(&mut desc) })?;
495 let this = Self { desc };
496 check(unsafe { set(this.desc, dtype.raw(), format.raw(), k, c, h, w) })?;
497 Ok(this)
498 }
499
500 #[inline]
502 pub fn as_raw(&self) -> cudnnFilterDescriptor_t {
503 self.desc
504 }
505}
506
507impl Drop for FilterDescriptor {
508 fn drop(&mut self) {
509 if let Ok(c) = cudnn() {
510 if let Ok(cu) = c.cudnn_destroy_filter_descriptor() {
511 let _ = unsafe { cu(self.desc) };
512 }
513 }
514 }
515}
516
517#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
519pub enum ConvMode {
520 Convolution,
522 #[default]
524 CrossCorrelation,
525}
526
527impl ConvMode {
528 fn raw(self) -> cudnnConvolutionMode_t {
529 match self {
530 ConvMode::Convolution => cudnnConvolutionMode_t::Convolution,
531 ConvMode::CrossCorrelation => cudnnConvolutionMode_t::CrossCorrelation,
532 }
533 }
534}
535
536#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
540pub enum FwdAlgo {
541 #[default]
543 ImplicitGemm,
544 ImplicitPrecompGemm,
546 Gemm,
548 Direct,
550 Fft,
552 FftTiling,
554 Winograd,
556 WinogradNonfused,
558}
559
560impl FwdAlgo {
561 fn raw(self) -> cudnnConvolutionFwdAlgo_t {
562 match self {
563 FwdAlgo::ImplicitGemm => cudnnConvolutionFwdAlgo_t::ImplicitGemm,
564 FwdAlgo::ImplicitPrecompGemm => cudnnConvolutionFwdAlgo_t::ImplicitPrecompGemm,
565 FwdAlgo::Gemm => cudnnConvolutionFwdAlgo_t::Gemm,
566 FwdAlgo::Direct => cudnnConvolutionFwdAlgo_t::Direct,
567 FwdAlgo::Fft => cudnnConvolutionFwdAlgo_t::Fft,
568 FwdAlgo::FftTiling => cudnnConvolutionFwdAlgo_t::FftTiling,
569 FwdAlgo::Winograd => cudnnConvolutionFwdAlgo_t::Winograd,
570 FwdAlgo::WinogradNonfused => cudnnConvolutionFwdAlgo_t::WinogradNonfused,
571 }
572 }
573}
574
575pub struct ConvolutionDescriptor {
577 desc: cudnnConvolutionDescriptor_t,
578}
579
580unsafe impl Send for ConvolutionDescriptor {}
581
582impl core::fmt::Debug for ConvolutionDescriptor {
583 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
584 f.debug_struct("ConvolutionDescriptor")
585 .field("desc", &self.desc)
586 .finish_non_exhaustive()
587 }
588}
589
590impl ConvolutionDescriptor {
591 #[allow(clippy::too_many_arguments)]
598 pub fn new_2d(
599 pad_h: i32,
600 pad_w: i32,
601 stride_h: i32,
602 stride_w: i32,
603 dilation_h: i32,
604 dilation_w: i32,
605 mode: ConvMode,
606 compute: DType,
607 ) -> Result<Self> {
608 let cu = cudnn()?;
609 let create = cu.cudnn_create_convolution_descriptor()?;
610 let set = cu.cudnn_set_convolution_2d_descriptor()?;
611 let mut desc: cudnnConvolutionDescriptor_t = core::ptr::null_mut();
612 check(unsafe { create(&mut desc) })?;
613 let this = Self { desc };
614 check(unsafe {
615 set(
616 this.desc,
617 pad_h,
618 pad_w,
619 stride_h,
620 stride_w,
621 dilation_h,
622 dilation_w,
623 mode.raw(),
624 compute.raw(),
625 )
626 })?;
627 Ok(this)
628 }
629
630 pub fn output_dim_2d(
633 &self,
634 input: &TensorDescriptor,
635 filter: &FilterDescriptor,
636 ) -> Result<(i32, i32, i32, i32)> {
637 let cu = cudnn()?;
638 let q = cu.cudnn_get_convolution_2d_forward_output_dim()?;
639 let mut n: core::ffi::c_int = 0;
640 let mut c: core::ffi::c_int = 0;
641 let mut h: core::ffi::c_int = 0;
642 let mut w: core::ffi::c_int = 0;
643 check(unsafe {
644 q(
645 self.desc,
646 input.desc,
647 filter.desc,
648 &mut n,
649 &mut c,
650 &mut h,
651 &mut w,
652 )
653 })?;
654 Ok((n, c, h, w))
655 }
656
657 pub fn set_group_count(&self, group_count: i32) -> Result<()> {
662 let cu = cudnn()?;
663 let f = cu.cudnn_set_convolution_group_count()?;
664 check(unsafe { f(self.desc, group_count) })
665 }
666
667 pub fn group_count(&self) -> Result<i32> {
669 let cu = cudnn()?;
670 let f = cu.cudnn_get_convolution_group_count()?;
671 let mut g: core::ffi::c_int = 0;
672 check(unsafe { f(self.desc, &mut g) })?;
673 Ok(g)
674 }
675
676 pub fn set_math_type(&self, math: MathType) -> Result<()> {
679 let cu = cudnn()?;
680 let f = cu.cudnn_set_convolution_math_type()?;
681 check(unsafe { f(self.desc, math.raw()) })
682 }
683
684 pub fn math_type(&self) -> Result<MathType> {
686 let cu = cudnn()?;
687 let f = cu.cudnn_get_convolution_math_type()?;
688 let mut m = cudnnMathType_t::DefaultMath;
689 check(unsafe { f(self.desc, &mut m) })?;
690 Ok(MathType::from_raw(m))
691 }
692
693 pub fn set_reorder_type(&self, reorder: ReorderType) -> Result<()> {
695 let cu = cudnn()?;
696 let f = cu.cudnn_set_convolution_reorder_type()?;
697 check(unsafe { f(self.desc, reorder.raw()) })
698 }
699
700 pub fn reorder_type(&self) -> Result<ReorderType> {
702 let cu = cudnn()?;
703 let f = cu.cudnn_get_convolution_reorder_type()?;
704 let mut r = cudnnReorderType_t::DefaultReorder;
705 check(unsafe { f(self.desc, &mut r) })?;
706 Ok(ReorderType::from_raw(r))
707 }
708
709 #[inline]
711 pub fn as_raw(&self) -> cudnnConvolutionDescriptor_t {
712 self.desc
713 }
714}
715
716impl Drop for ConvolutionDescriptor {
717 fn drop(&mut self) {
718 if let Ok(c) = cudnn() {
719 if let Ok(cu) = c.cudnn_destroy_convolution_descriptor() {
720 let _ = unsafe { cu(self.desc) };
721 }
722 }
723 }
724}
725
726pub fn convolution_forward_workspace_size(
729 handle: &Handle,
730 x: &TensorDescriptor,
731 w: &FilterDescriptor,
732 conv: &ConvolutionDescriptor,
733 y: &TensorDescriptor,
734 algo: FwdAlgo,
735) -> Result<usize> {
736 let cu = cudnn()?;
737 let q = cu.cudnn_get_convolution_forward_workspace_size()?;
738 let mut size: usize = 0;
739 check(unsafe {
740 q(
741 handle.handle,
742 x.desc,
743 w.desc,
744 conv.desc,
745 y.desc,
746 algo.raw(),
747 &mut size,
748 )
749 })?;
750 Ok(size)
751}
752
753#[allow(clippy::too_many_arguments)]
823pub fn convolution_forward<T: DeviceRepr>(
824 handle: &Handle,
825 alpha: f32,
826 x_desc: &TensorDescriptor,
827 x: &DeviceBuffer<T>,
828 w_desc: &FilterDescriptor,
829 w: &DeviceBuffer<T>,
830 conv: &ConvolutionDescriptor,
831 algo: FwdAlgo,
832 workspace: &mut DeviceBuffer<u8>,
833 beta: f32,
834 y_desc: &TensorDescriptor,
835 y: &mut DeviceBuffer<T>,
836) -> Result<()> {
837 let c = cudnn()?;
838 let cu = c.cudnn_convolution_forward()?;
839 check(unsafe {
840 cu(
841 handle.handle,
842 &alpha as *const f32 as *const core::ffi::c_void,
843 x_desc.desc,
844 x.as_raw().0 as *const core::ffi::c_void,
845 w_desc.desc,
846 w.as_raw().0 as *const core::ffi::c_void,
847 conv.desc,
848 algo.raw(),
849 workspace.as_raw().0 as *mut core::ffi::c_void,
850 workspace.byte_size(),
851 &beta as *const f32 as *const core::ffi::c_void,
852 y_desc.desc,
853 y.as_raw().0 as *mut core::ffi::c_void,
854 )
855 })
856}
857
858#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
862pub enum BwdDataAlgo {
863 #[default]
865 Algo0,
866 Algo1,
868 Fft,
870 FftTiling,
872 Winograd,
874 WinogradNonfused,
876}
877
878impl BwdDataAlgo {
879 fn raw(self) -> cudnnConvolutionBwdDataAlgo_t {
880 match self {
881 Self::Algo0 => cudnnConvolutionBwdDataAlgo_t::Algo0,
882 Self::Algo1 => cudnnConvolutionBwdDataAlgo_t::Algo1,
883 Self::Fft => cudnnConvolutionBwdDataAlgo_t::Fft,
884 Self::FftTiling => cudnnConvolutionBwdDataAlgo_t::FftTiling,
885 Self::Winograd => cudnnConvolutionBwdDataAlgo_t::Winograd,
886 Self::WinogradNonfused => cudnnConvolutionBwdDataAlgo_t::WinogradNonfused,
887 }
888 }
889}
890
891#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
893pub enum BwdFilterAlgo {
894 #[default]
896 Algo0,
897 Algo1,
899 Fft,
901 Algo3,
903 Winograd,
905 WinogradNonfused,
907 FftTiling,
909}
910
911impl BwdFilterAlgo {
912 fn raw(self) -> cudnnConvolutionBwdFilterAlgo_t {
913 match self {
914 Self::Algo0 => cudnnConvolutionBwdFilterAlgo_t::Algo0,
915 Self::Algo1 => cudnnConvolutionBwdFilterAlgo_t::Algo1,
916 Self::Fft => cudnnConvolutionBwdFilterAlgo_t::Fft,
917 Self::Algo3 => cudnnConvolutionBwdFilterAlgo_t::Algo3,
918 Self::Winograd => cudnnConvolutionBwdFilterAlgo_t::Winograd,
919 Self::WinogradNonfused => cudnnConvolutionBwdFilterAlgo_t::WinogradNonfused,
920 Self::FftTiling => cudnnConvolutionBwdFilterAlgo_t::FftTiling,
921 }
922 }
923}
924
925pub fn convolution_backward_data_workspace_size(
928 handle: &Handle,
929 w: &FilterDescriptor,
930 dy: &TensorDescriptor,
931 conv: &ConvolutionDescriptor,
932 dx: &TensorDescriptor,
933 algo: BwdDataAlgo,
934) -> Result<usize> {
935 let cu = cudnn()?;
936 let q = cu.cudnn_get_convolution_backward_data_workspace_size()?;
937 let mut size = 0usize;
938 check(unsafe {
939 q(
940 handle.handle,
941 w.desc,
942 dy.desc,
943 conv.desc,
944 dx.desc,
945 algo.raw(),
946 &mut size,
947 )
948 })?;
949 Ok(size)
950}
951
952pub fn convolution_backward_filter_workspace_size(
955 handle: &Handle,
956 x: &TensorDescriptor,
957 dy: &TensorDescriptor,
958 conv: &ConvolutionDescriptor,
959 dw: &FilterDescriptor,
960 algo: BwdFilterAlgo,
961) -> Result<usize> {
962 let cu = cudnn()?;
963 let q = cu.cudnn_get_convolution_backward_filter_workspace_size()?;
964 let mut size = 0usize;
965 check(unsafe {
966 q(
967 handle.handle,
968 x.desc,
969 dy.desc,
970 conv.desc,
971 dw.desc,
972 algo.raw(),
973 &mut size,
974 )
975 })?;
976 Ok(size)
977}
978
979#[allow(clippy::too_many_arguments)]
981pub fn convolution_backward_data<T: DeviceRepr>(
982 handle: &Handle,
983 alpha: f32,
984 w_desc: &FilterDescriptor,
985 w: &DeviceBuffer<T>,
986 dy_desc: &TensorDescriptor,
987 dy: &DeviceBuffer<T>,
988 conv: &ConvolutionDescriptor,
989 algo: BwdDataAlgo,
990 workspace: &mut DeviceBuffer<u8>,
991 beta: f32,
992 dx_desc: &TensorDescriptor,
993 dx: &mut DeviceBuffer<T>,
994) -> Result<()> {
995 let c = cudnn()?;
996 let cu = c.cudnn_convolution_backward_data()?;
997 check(unsafe {
998 cu(
999 handle.handle,
1000 &alpha as *const f32 as *const core::ffi::c_void,
1001 w_desc.desc,
1002 w.as_raw().0 as *const core::ffi::c_void,
1003 dy_desc.desc,
1004 dy.as_raw().0 as *const core::ffi::c_void,
1005 conv.desc,
1006 algo.raw(),
1007 workspace.as_raw().0 as *mut core::ffi::c_void,
1008 workspace.byte_size(),
1009 &beta as *const f32 as *const core::ffi::c_void,
1010 dx_desc.desc,
1011 dx.as_raw().0 as *mut core::ffi::c_void,
1012 )
1013 })
1014}
1015
1016#[allow(clippy::too_many_arguments)]
1018pub fn convolution_backward_filter<T: DeviceRepr>(
1019 handle: &Handle,
1020 alpha: f32,
1021 x_desc: &TensorDescriptor,
1022 x: &DeviceBuffer<T>,
1023 dy_desc: &TensorDescriptor,
1024 dy: &DeviceBuffer<T>,
1025 conv: &ConvolutionDescriptor,
1026 algo: BwdFilterAlgo,
1027 workspace: &mut DeviceBuffer<u8>,
1028 beta: f32,
1029 dw_desc: &FilterDescriptor,
1030 dw: &mut DeviceBuffer<T>,
1031) -> Result<()> {
1032 let c = cudnn()?;
1033 let cu = c.cudnn_convolution_backward_filter()?;
1034 check(unsafe {
1035 cu(
1036 handle.handle,
1037 &alpha as *const f32 as *const core::ffi::c_void,
1038 x_desc.desc,
1039 x.as_raw().0 as *const core::ffi::c_void,
1040 dy_desc.desc,
1041 dy.as_raw().0 as *const core::ffi::c_void,
1042 conv.desc,
1043 algo.raw(),
1044 workspace.as_raw().0 as *mut core::ffi::c_void,
1045 workspace.byte_size(),
1046 &beta as *const f32 as *const core::ffi::c_void,
1047 dw_desc.desc,
1048 dw.as_raw().0 as *mut core::ffi::c_void,
1049 )
1050 })
1051}
1052
1053pub fn convolution_backward_bias<T: DeviceRepr>(
1055 handle: &Handle,
1056 alpha: f32,
1057 dy_desc: &TensorDescriptor,
1058 dy: &DeviceBuffer<T>,
1059 beta: f32,
1060 db_desc: &TensorDescriptor,
1061 db: &mut DeviceBuffer<T>,
1062) -> Result<()> {
1063 let c = cudnn()?;
1064 let cu = c.cudnn_convolution_backward_bias()?;
1065 check(unsafe {
1066 cu(
1067 handle.handle,
1068 &alpha as *const f32 as *const core::ffi::c_void,
1069 dy_desc.desc,
1070 dy.as_raw().0 as *const core::ffi::c_void,
1071 &beta as *const f32 as *const core::ffi::c_void,
1072 db_desc.desc,
1073 db.as_raw().0 as *mut core::ffi::c_void,
1074 )
1075 })
1076}
1077
1078#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
1082pub enum PoolingMode {
1083 #[default]
1085 Max,
1086 AverageCountIncludePadding,
1088 AverageCountExcludePadding,
1090 MaxDeterministic,
1092}
1093
1094impl PoolingMode {
1095 fn raw(self) -> cudnnPoolingMode_t {
1096 match self {
1097 Self::Max => cudnnPoolingMode_t::Max,
1098 Self::AverageCountIncludePadding => cudnnPoolingMode_t::AverageCountIncludePadding,
1099 Self::AverageCountExcludePadding => cudnnPoolingMode_t::AverageCountExcludePadding,
1100 Self::MaxDeterministic => cudnnPoolingMode_t::MaxDeterministic,
1101 }
1102 }
1103}
1104
1105pub struct PoolingDescriptor {
1107 desc: cudnnPoolingDescriptor_t,
1108}
1109
1110unsafe impl Send for PoolingDescriptor {}
1111
1112impl core::fmt::Debug for PoolingDescriptor {
1113 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1114 f.debug_struct("PoolingDescriptor")
1115 .field("desc", &self.desc)
1116 .finish_non_exhaustive()
1117 }
1118}
1119
1120impl PoolingDescriptor {
1121 #[allow(clippy::too_many_arguments)]
1124 pub fn new_2d(
1125 mode: PoolingMode,
1126 window_h: i32,
1127 window_w: i32,
1128 pad_h: i32,
1129 pad_w: i32,
1130 stride_h: i32,
1131 stride_w: i32,
1132 ) -> Result<Self> {
1133 let cu = cudnn()?;
1134 let create = cu.cudnn_create_pooling_descriptor()?;
1135 let set = cu.cudnn_set_pooling_2d_descriptor()?;
1136 let mut desc: cudnnPoolingDescriptor_t = core::ptr::null_mut();
1137 check(unsafe { create(&mut desc) })?;
1138 let this = Self { desc };
1139 check(unsafe {
1140 set(
1141 this.desc,
1142 mode.raw(),
1143 cudnnNanPropagation_t::PropagateNan,
1144 window_h,
1145 window_w,
1146 pad_h,
1147 pad_w,
1148 stride_h,
1149 stride_w,
1150 )
1151 })?;
1152 Ok(this)
1153 }
1154
1155 #[inline]
1157 pub fn as_raw(&self) -> cudnnPoolingDescriptor_t {
1158 self.desc
1159 }
1160}
1161
1162impl Drop for PoolingDescriptor {
1163 fn drop(&mut self) {
1164 if let Ok(c) = cudnn() {
1165 if let Ok(cu) = c.cudnn_destroy_pooling_descriptor() {
1166 let _ = unsafe { cu(self.desc) };
1167 }
1168 }
1169 }
1170}
1171
1172#[allow(clippy::too_many_arguments)]
1206pub fn pooling_forward<T: DeviceRepr>(
1207 handle: &Handle,
1208 pool: &PoolingDescriptor,
1209 alpha: f32,
1210 x_desc: &TensorDescriptor,
1211 x: &DeviceBuffer<T>,
1212 beta: f32,
1213 y_desc: &TensorDescriptor,
1214 y: &mut DeviceBuffer<T>,
1215) -> Result<()> {
1216 let c = cudnn()?;
1217 let cu = c.cudnn_pooling_forward()?;
1218 check(unsafe {
1219 cu(
1220 handle.handle,
1221 pool.desc,
1222 &alpha as *const f32 as *const core::ffi::c_void,
1223 x_desc.desc,
1224 x.as_raw().0 as *const core::ffi::c_void,
1225 &beta as *const f32 as *const core::ffi::c_void,
1226 y_desc.desc,
1227 y.as_raw().0 as *mut core::ffi::c_void,
1228 )
1229 })
1230}
1231
1232#[allow(clippy::too_many_arguments)]
1234pub fn pooling_backward<T: DeviceRepr>(
1235 handle: &Handle,
1236 pool: &PoolingDescriptor,
1237 alpha: f32,
1238 y_desc: &TensorDescriptor,
1239 y: &DeviceBuffer<T>,
1240 dy_desc: &TensorDescriptor,
1241 dy: &DeviceBuffer<T>,
1242 x_desc: &TensorDescriptor,
1243 x: &DeviceBuffer<T>,
1244 beta: f32,
1245 dx_desc: &TensorDescriptor,
1246 dx: &mut DeviceBuffer<T>,
1247) -> Result<()> {
1248 let c = cudnn()?;
1249 let cu = c.cudnn_pooling_backward()?;
1250 check(unsafe {
1251 cu(
1252 handle.handle,
1253 pool.desc,
1254 &alpha as *const f32 as *const core::ffi::c_void,
1255 y_desc.desc,
1256 y.as_raw().0 as *const core::ffi::c_void,
1257 dy_desc.desc,
1258 dy.as_raw().0 as *const core::ffi::c_void,
1259 x_desc.desc,
1260 x.as_raw().0 as *const core::ffi::c_void,
1261 &beta as *const f32 as *const core::ffi::c_void,
1262 dx_desc.desc,
1263 dx.as_raw().0 as *mut core::ffi::c_void,
1264 )
1265 })
1266}
1267
1268#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
1272pub enum SoftmaxAlgo {
1273 Fast,
1275 #[default]
1277 Accurate,
1278 Log,
1280}
1281
1282impl SoftmaxAlgo {
1283 fn raw(self) -> cudnnSoftmaxAlgorithm_t {
1284 match self {
1285 Self::Fast => cudnnSoftmaxAlgorithm_t::Fast,
1286 Self::Accurate => cudnnSoftmaxAlgorithm_t::Accurate,
1287 Self::Log => cudnnSoftmaxAlgorithm_t::Log,
1288 }
1289 }
1290}
1291
1292#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
1294pub enum SoftmaxMode {
1295 Instance,
1297 #[default]
1299 Channel,
1300}
1301
1302impl SoftmaxMode {
1303 fn raw(self) -> cudnnSoftmaxMode_t {
1304 match self {
1305 Self::Instance => cudnnSoftmaxMode_t::Instance,
1306 Self::Channel => cudnnSoftmaxMode_t::Channel,
1307 }
1308 }
1309}
1310
1311#[allow(clippy::too_many_arguments)]
1313pub fn softmax_forward<T: DeviceRepr>(
1314 handle: &Handle,
1315 algo: SoftmaxAlgo,
1316 mode: SoftmaxMode,
1317 alpha: f32,
1318 x_desc: &TensorDescriptor,
1319 x: &DeviceBuffer<T>,
1320 beta: f32,
1321 y_desc: &TensorDescriptor,
1322 y: &mut DeviceBuffer<T>,
1323) -> Result<()> {
1324 let c = cudnn()?;
1325 let cu = c.cudnn_softmax_forward()?;
1326 check(unsafe {
1327 cu(
1328 handle.handle,
1329 algo.raw(),
1330 mode.raw(),
1331 &alpha as *const f32 as *const core::ffi::c_void,
1332 x_desc.desc,
1333 x.as_raw().0 as *const core::ffi::c_void,
1334 &beta as *const f32 as *const core::ffi::c_void,
1335 y_desc.desc,
1336 y.as_raw().0 as *mut core::ffi::c_void,
1337 )
1338 })
1339}
1340
1341#[allow(clippy::too_many_arguments)]
1343pub fn softmax_backward<T: DeviceRepr>(
1344 handle: &Handle,
1345 algo: SoftmaxAlgo,
1346 mode: SoftmaxMode,
1347 alpha: f32,
1348 y_desc: &TensorDescriptor,
1349 y: &DeviceBuffer<T>,
1350 dy_desc: &TensorDescriptor,
1351 dy: &DeviceBuffer<T>,
1352 beta: f32,
1353 dx_desc: &TensorDescriptor,
1354 dx: &mut DeviceBuffer<T>,
1355) -> Result<()> {
1356 let c = cudnn()?;
1357 let cu = c.cudnn_softmax_backward()?;
1358 check(unsafe {
1359 cu(
1360 handle.handle,
1361 algo.raw(),
1362 mode.raw(),
1363 &alpha as *const f32 as *const core::ffi::c_void,
1364 y_desc.desc,
1365 y.as_raw().0 as *const core::ffi::c_void,
1366 dy_desc.desc,
1367 dy.as_raw().0 as *const core::ffi::c_void,
1368 &beta as *const f32 as *const core::ffi::c_void,
1369 dx_desc.desc,
1370 dx.as_raw().0 as *mut core::ffi::c_void,
1371 )
1372 })
1373}
1374
1375#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
1379pub enum BatchNormMode {
1380 PerActivation,
1382 #[default]
1384 Spatial,
1385 SpatialPersistent,
1387}
1388
1389impl BatchNormMode {
1390 fn raw(self) -> cudnnBatchNormMode_t {
1391 match self {
1392 Self::PerActivation => cudnnBatchNormMode_t::PerActivation,
1393 Self::Spatial => cudnnBatchNormMode_t::Spatial,
1394 Self::SpatialPersistent => cudnnBatchNormMode_t::SpatialPersistent,
1395 }
1396 }
1397}
1398
1399#[allow(clippy::too_many_arguments)]
1402pub fn batch_normalization_forward_training<T: DeviceRepr>(
1403 handle: &Handle,
1404 mode: BatchNormMode,
1405 alpha: f32,
1406 beta: f32,
1407 x_desc: &TensorDescriptor,
1408 x: &DeviceBuffer<T>,
1409 y_desc: &TensorDescriptor,
1410 y: &mut DeviceBuffer<T>,
1411 bn_smbv_desc: &TensorDescriptor,
1412 bn_scale: &DeviceBuffer<T>,
1413 bn_bias: &DeviceBuffer<T>,
1414 exponential_avg_factor: f64,
1415 running_mean: &mut DeviceBuffer<T>,
1416 running_variance: &mut DeviceBuffer<T>,
1417 epsilon: f64,
1418 saved_mean: &mut DeviceBuffer<T>,
1419 saved_inv_variance: &mut DeviceBuffer<T>,
1420) -> Result<()> {
1421 let c = cudnn()?;
1422 let cu = c.cudnn_batch_normalization_forward_training()?;
1423 check(unsafe {
1424 cu(
1425 handle.handle,
1426 mode.raw(),
1427 &alpha as *const f32 as *const core::ffi::c_void,
1428 &beta as *const f32 as *const core::ffi::c_void,
1429 x_desc.desc,
1430 x.as_raw().0 as *const core::ffi::c_void,
1431 y_desc.desc,
1432 y.as_raw().0 as *mut core::ffi::c_void,
1433 bn_smbv_desc.desc,
1434 bn_scale.as_raw().0 as *const core::ffi::c_void,
1435 bn_bias.as_raw().0 as *const core::ffi::c_void,
1436 exponential_avg_factor,
1437 running_mean.as_raw().0 as *mut core::ffi::c_void,
1438 running_variance.as_raw().0 as *mut core::ffi::c_void,
1439 epsilon,
1440 saved_mean.as_raw().0 as *mut core::ffi::c_void,
1441 saved_inv_variance.as_raw().0 as *mut core::ffi::c_void,
1442 )
1443 })
1444}
1445
1446#[allow(clippy::too_many_arguments)]
1448pub fn batch_normalization_backward<T: DeviceRepr>(
1449 handle: &Handle,
1450 mode: BatchNormMode,
1451 alpha_data_diff: f32,
1452 beta_data_diff: f32,
1453 alpha_param_diff: f32,
1454 beta_param_diff: f32,
1455 x_desc: &TensorDescriptor,
1456 x: &DeviceBuffer<T>,
1457 dy_desc: &TensorDescriptor,
1458 dy: &DeviceBuffer<T>,
1459 dx_desc: &TensorDescriptor,
1460 dx: &mut DeviceBuffer<T>,
1461 bn_scale_bias_diff_desc: &TensorDescriptor,
1462 bn_scale: &DeviceBuffer<T>,
1463 d_bn_scale: &mut DeviceBuffer<T>,
1464 d_bn_bias: &mut DeviceBuffer<T>,
1465 epsilon: f64,
1466 saved_mean: &DeviceBuffer<T>,
1467 saved_inv_variance: &DeviceBuffer<T>,
1468) -> Result<()> {
1469 let c = cudnn()?;
1470 let cu = c.cudnn_batch_normalization_backward()?;
1471 check(unsafe {
1472 cu(
1473 handle.handle,
1474 mode.raw(),
1475 &alpha_data_diff as *const f32 as *const core::ffi::c_void,
1476 &beta_data_diff as *const f32 as *const core::ffi::c_void,
1477 &alpha_param_diff as *const f32 as *const core::ffi::c_void,
1478 &beta_param_diff as *const f32 as *const core::ffi::c_void,
1479 x_desc.desc,
1480 x.as_raw().0 as *const core::ffi::c_void,
1481 dy_desc.desc,
1482 dy.as_raw().0 as *const core::ffi::c_void,
1483 dx_desc.desc,
1484 dx.as_raw().0 as *mut core::ffi::c_void,
1485 bn_scale_bias_diff_desc.desc,
1486 bn_scale.as_raw().0 as *const core::ffi::c_void,
1487 d_bn_scale.as_raw().0 as *mut core::ffi::c_void,
1488 d_bn_bias.as_raw().0 as *mut core::ffi::c_void,
1489 epsilon,
1490 saved_mean.as_raw().0 as *const core::ffi::c_void,
1491 saved_inv_variance.as_raw().0 as *const core::ffi::c_void,
1492 )
1493 })
1494}
1495
1496#[allow(clippy::too_many_arguments)]
1537pub fn batch_normalization_forward_inference<T: DeviceRepr>(
1538 handle: &Handle,
1539 mode: BatchNormMode,
1540 alpha: f32,
1541 beta: f32,
1542 x_desc: &TensorDescriptor,
1543 x: &DeviceBuffer<T>,
1544 y_desc: &TensorDescriptor,
1545 y: &mut DeviceBuffer<T>,
1546 bn_smbv_desc: &TensorDescriptor,
1547 bn_scale: &DeviceBuffer<T>,
1548 bn_bias: &DeviceBuffer<T>,
1549 estimated_mean: &DeviceBuffer<T>,
1550 estimated_var: &DeviceBuffer<T>,
1551 epsilon: f64,
1552) -> Result<()> {
1553 let c = cudnn()?;
1554 let cu = c.cudnn_batch_normalization_forward_inference()?;
1555 check(unsafe {
1556 cu(
1557 handle.handle,
1558 mode.raw(),
1559 &alpha as *const f32 as *const core::ffi::c_void,
1560 &beta as *const f32 as *const core::ffi::c_void,
1561 x_desc.desc,
1562 x.as_raw().0 as *const core::ffi::c_void,
1563 y_desc.desc,
1564 y.as_raw().0 as *mut core::ffi::c_void,
1565 bn_smbv_desc.desc,
1566 bn_scale.as_raw().0 as *const core::ffi::c_void,
1567 bn_bias.as_raw().0 as *const core::ffi::c_void,
1568 estimated_mean.as_raw().0 as *const core::ffi::c_void,
1569 estimated_var.as_raw().0 as *const core::ffi::c_void,
1570 epsilon,
1571 )
1572 })
1573}
1574
1575pub struct DropoutDescriptor {
1579 desc: cudnnDropoutDescriptor_t,
1580}
1581
1582unsafe impl Send for DropoutDescriptor {}
1583
1584impl core::fmt::Debug for DropoutDescriptor {
1585 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1586 f.debug_struct("DropoutDescriptor")
1587 .field("desc", &self.desc)
1588 .finish_non_exhaustive()
1589 }
1590}
1591
1592impl DropoutDescriptor {
1593 pub fn new(
1598 handle: &Handle,
1599 dropout: f32,
1600 states: &mut DeviceBuffer<u8>,
1601 seed: u64,
1602 ) -> Result<Self> {
1603 let cu = cudnn()?;
1604 let create = cu.cudnn_create_dropout_descriptor()?;
1605 let set = cu.cudnn_set_dropout_descriptor()?;
1606 let mut desc: cudnnDropoutDescriptor_t = core::ptr::null_mut();
1607 check(unsafe { create(&mut desc) })?;
1608 let this = Self { desc };
1609 check(unsafe {
1610 set(
1611 this.desc,
1612 handle.handle,
1613 dropout,
1614 states.as_raw().0 as *mut core::ffi::c_void,
1615 states.byte_size(),
1616 seed,
1617 )
1618 })?;
1619 Ok(this)
1620 }
1621
1622 #[inline]
1624 pub fn as_raw(&self) -> cudnnDropoutDescriptor_t {
1625 self.desc
1626 }
1627}
1628
1629impl Drop for DropoutDescriptor {
1630 fn drop(&mut self) {
1631 if let Ok(c) = cudnn() {
1632 if let Ok(cu) = c.cudnn_destroy_dropout_descriptor() {
1633 let _ = unsafe { cu(self.desc) };
1634 }
1635 }
1636 }
1637}
1638
1639pub fn dropout_states_size(handle: &Handle) -> Result<usize> {
1641 let c = cudnn()?;
1642 let cu = c.cudnn_dropout_get_states_size()?;
1643 let mut size = 0usize;
1644 check(unsafe { cu(handle.handle, &mut size) })?;
1645 Ok(size)
1646}
1647
1648pub fn dropout_reserve_size(x: &TensorDescriptor) -> Result<usize> {
1650 let c = cudnn()?;
1651 let cu = c.cudnn_dropout_get_reserve_space_size()?;
1652 let mut size = 0usize;
1653 check(unsafe { cu(x.desc, &mut size) })?;
1654 Ok(size)
1655}
1656
1657#[allow(clippy::too_many_arguments)]
1660pub fn dropout_forward<T: DeviceRepr>(
1661 handle: &Handle,
1662 dropout: &DropoutDescriptor,
1663 x_desc: &TensorDescriptor,
1664 x: &DeviceBuffer<T>,
1665 y_desc: &TensorDescriptor,
1666 y: &mut DeviceBuffer<T>,
1667 reserve: &mut DeviceBuffer<u8>,
1668) -> Result<()> {
1669 let c = cudnn()?;
1670 let cu = c.cudnn_dropout_forward()?;
1671 check(unsafe {
1672 cu(
1673 handle.handle,
1674 dropout.desc,
1675 x_desc.desc,
1676 x.as_raw().0 as *const core::ffi::c_void,
1677 y_desc.desc,
1678 y.as_raw().0 as *mut core::ffi::c_void,
1679 reserve.as_raw().0 as *mut core::ffi::c_void,
1680 reserve.byte_size(),
1681 )
1682 })
1683}
1684
1685#[allow(clippy::too_many_arguments)]
1689pub fn dropout_backward<T: DeviceRepr>(
1690 handle: &Handle,
1691 dropout: &DropoutDescriptor,
1692 dy_desc: &TensorDescriptor,
1693 dy: &DeviceBuffer<T>,
1694 dx_desc: &TensorDescriptor,
1695 dx: &mut DeviceBuffer<T>,
1696 reserve: &mut DeviceBuffer<u8>,
1697) -> Result<()> {
1698 let c = cudnn()?;
1699 let cu = c.cudnn_dropout_backward()?;
1700 check(unsafe {
1701 cu(
1702 handle.handle,
1703 dropout.desc,
1704 dy_desc.desc,
1705 dy.as_raw().0 as *const core::ffi::c_void,
1706 dx_desc.desc,
1707 dx.as_raw().0 as *mut core::ffi::c_void,
1708 reserve.as_raw().0 as *mut core::ffi::c_void,
1709 reserve.byte_size(),
1710 )
1711 })
1712}
1713
1714pub struct LrnDescriptor {
1718 desc: cudnnLRNDescriptor_t,
1719}
1720
1721unsafe impl Send for LrnDescriptor {}
1722
1723impl core::fmt::Debug for LrnDescriptor {
1724 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1725 f.debug_struct("LrnDescriptor")
1726 .field("desc", &self.desc)
1727 .finish_non_exhaustive()
1728 }
1729}
1730
1731impl LrnDescriptor {
1732 pub fn new(n: i32, alpha: f64, beta: f64, k: f64) -> Result<Self> {
1735 let cu = cudnn()?;
1736 let create = cu.cudnn_create_lrn_descriptor()?;
1737 let set = cu.cudnn_set_lrn_descriptor()?;
1738 let mut desc: cudnnLRNDescriptor_t = core::ptr::null_mut();
1739 check(unsafe { create(&mut desc) })?;
1740 let this = Self { desc };
1741 check(unsafe { set(this.desc, n, alpha, beta, k) })?;
1742 Ok(this)
1743 }
1744
1745 #[inline]
1747 pub fn as_raw(&self) -> cudnnLRNDescriptor_t {
1748 self.desc
1749 }
1750}
1751
1752impl Drop for LrnDescriptor {
1753 fn drop(&mut self) {
1754 if let Ok(c) = cudnn() {
1755 if let Ok(cu) = c.cudnn_destroy_lrn_descriptor() {
1756 let _ = unsafe { cu(self.desc) };
1757 }
1758 }
1759 }
1760}
1761
1762#[derive(Copy, Clone, Debug, Eq, PartialEq)]
1766pub enum OpTensorOp {
1767 Add,
1769 Mul,
1771 Min,
1773 Max,
1775 Sqrt,
1777 Not,
1779}
1780
1781impl OpTensorOp {
1782 fn raw(self) -> cudnnOpTensorOp_t {
1783 match self {
1784 Self::Add => cudnnOpTensorOp_t::Add,
1785 Self::Mul => cudnnOpTensorOp_t::Mul,
1786 Self::Min => cudnnOpTensorOp_t::Min,
1787 Self::Max => cudnnOpTensorOp_t::Max,
1788 Self::Sqrt => cudnnOpTensorOp_t::Sqrt,
1789 Self::Not => cudnnOpTensorOp_t::Not,
1790 }
1791 }
1792}
1793
1794pub struct OpTensorDescriptor {
1796 desc: cudnnOpTensorDescriptor_t,
1797}
1798
1799unsafe impl Send for OpTensorDescriptor {}
1800
1801impl core::fmt::Debug for OpTensorDescriptor {
1802 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1803 f.debug_struct("OpTensorDescriptor")
1804 .field("desc", &self.desc)
1805 .finish_non_exhaustive()
1806 }
1807}
1808
1809impl OpTensorDescriptor {
1810 pub fn new(op: OpTensorOp, compute: DType) -> Result<Self> {
1813 let cu = cudnn()?;
1814 let create = cu.cudnn_create_op_tensor_descriptor()?;
1815 let set = cu.cudnn_set_op_tensor_descriptor()?;
1816 let mut desc: cudnnOpTensorDescriptor_t = core::ptr::null_mut();
1817 check(unsafe { create(&mut desc) })?;
1818 let this = Self { desc };
1819 check(unsafe {
1820 set(
1821 this.desc,
1822 op.raw(),
1823 compute.raw(),
1824 cudnnNanPropagation_t::PropagateNan,
1825 )
1826 })?;
1827 Ok(this)
1828 }
1829
1830 #[inline]
1832 pub fn as_raw(&self) -> cudnnOpTensorDescriptor_t {
1833 self.desc
1834 }
1835}
1836
1837impl Drop for OpTensorDescriptor {
1838 fn drop(&mut self) {
1839 if let Ok(c) = cudnn() {
1840 if let Ok(cu) = c.cudnn_destroy_op_tensor_descriptor() {
1841 let _ = unsafe { cu(self.desc) };
1842 }
1843 }
1844 }
1845}
1846
1847#[allow(clippy::too_many_arguments)]
1849pub fn op_tensor<T: DeviceRepr>(
1850 handle: &Handle,
1851 op: &OpTensorDescriptor,
1852 alpha1: f32,
1853 a_desc: &TensorDescriptor,
1854 a: &DeviceBuffer<T>,
1855 alpha2: f32,
1856 b_desc: &TensorDescriptor,
1857 b: &DeviceBuffer<T>,
1858 beta: f32,
1859 c_desc: &TensorDescriptor,
1860 c: &mut DeviceBuffer<T>,
1861) -> Result<()> {
1862 let cu_crate = cudnn()?;
1863 let cu = cu_crate.cudnn_op_tensor()?;
1864 check(unsafe {
1865 cu(
1866 handle.handle,
1867 op.desc,
1868 &alpha1 as *const f32 as *const core::ffi::c_void,
1869 a_desc.desc,
1870 a.as_raw().0 as *const core::ffi::c_void,
1871 &alpha2 as *const f32 as *const core::ffi::c_void,
1872 b_desc.desc,
1873 b.as_raw().0 as *const core::ffi::c_void,
1874 &beta as *const f32 as *const core::ffi::c_void,
1875 c_desc.desc,
1876 c.as_raw().0 as *mut core::ffi::c_void,
1877 )
1878 })
1879}
1880
1881#[derive(Copy, Clone, Debug, Eq, PartialEq)]
1883pub enum ReduceOp {
1884 Add,
1886 Mul,
1888 Min,
1890 Max,
1892 AbsMax,
1894 Avg,
1896 Norm1,
1898 Norm2,
1900 MulNoZeros,
1902}
1903
1904impl ReduceOp {
1905 fn raw(self) -> cudnnReduceTensorOp_t {
1906 match self {
1907 Self::Add => cudnnReduceTensorOp_t::Add,
1908 Self::Mul => cudnnReduceTensorOp_t::Mul,
1909 Self::Min => cudnnReduceTensorOp_t::Min,
1910 Self::Max => cudnnReduceTensorOp_t::Max,
1911 Self::AbsMax => cudnnReduceTensorOp_t::Amax,
1912 Self::Avg => cudnnReduceTensorOp_t::Avg,
1913 Self::Norm1 => cudnnReduceTensorOp_t::Norm1,
1914 Self::Norm2 => cudnnReduceTensorOp_t::Norm2,
1915 Self::MulNoZeros => cudnnReduceTensorOp_t::MulNoZeros,
1916 }
1917 }
1918}
1919
1920pub struct ReduceTensorDescriptor {
1922 desc: cudnnReduceTensorDescriptor_t,
1923}
1924
1925unsafe impl Send for ReduceTensorDescriptor {}
1926
1927impl core::fmt::Debug for ReduceTensorDescriptor {
1928 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1929 f.debug_struct("ReduceTensorDescriptor")
1930 .field("desc", &self.desc)
1931 .finish_non_exhaustive()
1932 }
1933}
1934
1935impl ReduceTensorDescriptor {
1936 pub fn new(op: ReduceOp, compute: DType) -> Result<Self> {
1941 let cu = cudnn()?;
1942 let create = cu.cudnn_create_reduce_tensor_descriptor()?;
1943 let set = cu.cudnn_set_reduce_tensor_descriptor()?;
1944 let mut desc: cudnnReduceTensorDescriptor_t = core::ptr::null_mut();
1945 check(unsafe { create(&mut desc) })?;
1946 let this = Self { desc };
1947 check(unsafe {
1948 set(
1949 this.desc,
1950 op.raw(),
1951 compute.raw(),
1952 cudnnNanPropagation_t::PropagateNan,
1953 cudnnReduceTensorIndices_t::NoIndices,
1954 cudnnIndicesType_t::U32,
1955 )
1956 })?;
1957 Ok(this)
1958 }
1959
1960 pub fn workspace_size(
1962 &self,
1963 handle: &Handle,
1964 a: &TensorDescriptor,
1965 c: &TensorDescriptor,
1966 ) -> Result<usize> {
1967 let cu = cudnn()?;
1968 let q = cu.cudnn_get_reduction_workspace_size()?;
1969 let mut size = 0usize;
1970 check(unsafe { q(handle.handle, self.desc, a.desc, c.desc, &mut size) })?;
1971 Ok(size)
1972 }
1973
1974 #[inline]
1976 pub fn as_raw(&self) -> cudnnReduceTensorDescriptor_t {
1977 self.desc
1978 }
1979}
1980
1981impl Drop for ReduceTensorDescriptor {
1982 fn drop(&mut self) {
1983 if let Ok(c) = cudnn() {
1984 if let Ok(cu) = c.cudnn_destroy_reduce_tensor_descriptor() {
1985 let _ = unsafe { cu(self.desc) };
1986 }
1987 }
1988 }
1989}
1990
1991#[allow(clippy::too_many_arguments)]
1994pub fn reduce_tensor<T: DeviceRepr>(
1995 handle: &Handle,
1996 reducer: &ReduceTensorDescriptor,
1997 workspace: &mut DeviceBuffer<u8>,
1998 alpha: f32,
1999 a_desc: &TensorDescriptor,
2000 a: &DeviceBuffer<T>,
2001 beta: f32,
2002 c_desc: &TensorDescriptor,
2003 c: &mut DeviceBuffer<T>,
2004) -> Result<()> {
2005 let cu_crate = cudnn()?;
2006 let cu = cu_crate.cudnn_reduce_tensor()?;
2007 check(unsafe {
2008 cu(
2009 handle.handle,
2010 reducer.desc,
2011 core::ptr::null_mut(),
2012 0,
2013 workspace.as_raw().0 as *mut core::ffi::c_void,
2014 workspace.byte_size(),
2015 &alpha as *const f32 as *const core::ffi::c_void,
2016 a_desc.desc,
2017 a.as_raw().0 as *const core::ffi::c_void,
2018 &beta as *const f32 as *const core::ffi::c_void,
2019 c_desc.desc,
2020 c.as_raw().0 as *mut core::ffi::c_void,
2021 )
2022 })
2023}
2024
2025pub fn add_tensor<T: DeviceRepr>(
2028 handle: &Handle,
2029 alpha: f32,
2030 a_desc: &TensorDescriptor,
2031 a: &DeviceBuffer<T>,
2032 beta: f32,
2033 c_desc: &TensorDescriptor,
2034 c: &mut DeviceBuffer<T>,
2035) -> Result<()> {
2036 let cu_crate = cudnn()?;
2037 let cu = cu_crate.cudnn_add_tensor()?;
2038 check(unsafe {
2039 cu(
2040 handle.handle,
2041 &alpha as *const f32 as *const core::ffi::c_void,
2042 a_desc.desc,
2043 a.as_raw().0 as *const core::ffi::c_void,
2044 &beta as *const f32 as *const core::ffi::c_void,
2045 c_desc.desc,
2046 c.as_raw().0 as *mut core::ffi::c_void,
2047 )
2048 })
2049}
2050
2051pub struct BackendDescriptor {
2059 desc: cudnnBackendDescriptor_t,
2060 finalized: bool,
2061}
2062
2063unsafe impl Send for BackendDescriptor {}
2064
2065impl core::fmt::Debug for BackendDescriptor {
2066 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2067 f.debug_struct("BackendDescriptor")
2068 .field("desc", &self.desc)
2069 .field("finalized", &self.finalized)
2070 .finish()
2071 }
2072}
2073
2074impl BackendDescriptor {
2075 pub fn new(kind: cudnnBackendDescriptorType_t) -> Result<Self> {
2080 let cu = cudnn()?;
2081 let create = cu.cudnn_backend_create_descriptor()?;
2082 let init = cu.cudnn_backend_initialize()?;
2083 let mut desc: cudnnBackendDescriptor_t = core::ptr::null_mut();
2084 check(unsafe { create(kind, &mut desc) })?;
2085 let this = Self {
2086 desc,
2087 finalized: false,
2088 };
2089 check(unsafe { init(this.desc) })?;
2090 Ok(this)
2091 }
2092
2093 pub unsafe fn set_attribute_raw(
2100 &self,
2101 name: cudnnBackendAttributeName_t,
2102 ty: cudnnBackendAttributeType_t,
2103 element_count: i64,
2104 array_of_elements: *const core::ffi::c_void,
2105 ) -> Result<()> { unsafe {
2106 let cu = cudnn()?;
2107 let f = cu.cudnn_backend_set_attribute()?;
2108 check(f(self.desc, name, ty, element_count, array_of_elements))
2109 }}
2110
2111 pub fn finalize(&mut self) -> Result<()> {
2114 if self.finalized {
2115 return Ok(());
2116 }
2117 let cu = cudnn()?;
2118 let f = cu.cudnn_backend_finalize()?;
2119 check(unsafe { f(self.desc) })?;
2120 self.finalized = true;
2121 Ok(())
2122 }
2123
2124 pub fn execute(&self, handle: &Handle, variant_pack: &BackendDescriptor) -> Result<()> {
2127 let cu = cudnn()?;
2128 let f = cu.cudnn_backend_execute()?;
2129 check(unsafe { f(handle.handle, self.desc, variant_pack.desc) })
2130 }
2131
2132 #[inline]
2134 pub fn as_raw(&self) -> cudnnBackendDescriptor_t {
2135 self.desc
2136 }
2137}
2138
2139impl Drop for BackendDescriptor {
2140 fn drop(&mut self) {
2141 if let Ok(c) = cudnn() {
2142 if let Ok(cu) = c.cudnn_backend_destroy_descriptor() {
2143 let _ = unsafe { cu(self.desc) };
2144 }
2145 }
2146 }
2147}
2148
2149pub use baracuda_cudnn_sys::{
2152 cudnnBackendAttributeName_t as BackendAttrName,
2153 cudnnBackendAttributeType_t as BackendAttrType,
2154 cudnnBackendDescriptorType_t as BackendDescType,
2155};
2156
2157use baracuda_cudnn_sys::cudnnCTCLossDescriptor_t;
2160
2161pub struct CtcLossDescriptor {
2163 desc: cudnnCTCLossDescriptor_t,
2164}
2165
2166unsafe impl Send for CtcLossDescriptor {}
2167
2168impl core::fmt::Debug for CtcLossDescriptor {
2169 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2170 f.debug_struct("CtcLossDescriptor")
2171 .field("desc", &self.desc)
2172 .finish_non_exhaustive()
2173 }
2174}
2175
2176impl CtcLossDescriptor {
2177 pub fn new(compute: DType) -> Result<Self> {
2179 let cu = cudnn()?;
2180 let create = cu.cudnn_create_ctc_loss_descriptor()?;
2181 let set = cu.cudnn_set_ctc_loss_descriptor()?;
2182 let mut desc: cudnnCTCLossDescriptor_t = core::ptr::null_mut();
2183 check(unsafe { create(&mut desc) })?;
2184 let this = Self { desc };
2185 check(unsafe { set(this.desc, compute.raw()) })?;
2186 Ok(this)
2187 }
2188
2189 #[inline]
2191 pub fn as_raw(&self) -> cudnnCTCLossDescriptor_t {
2192 self.desc
2193 }
2194}
2195
2196impl Drop for CtcLossDescriptor {
2197 fn drop(&mut self) {
2198 if let Ok(c) = cudnn() {
2199 if let Ok(cu) = c.cudnn_destroy_ctc_loss_descriptor() {
2200 let _ = unsafe { cu(self.desc) };
2201 }
2202 }
2203 }
2204}
2205
2206#[allow(clippy::too_many_arguments)]
2208pub fn ctc_loss_workspace_size(
2209 handle: &Handle,
2210 probs: &TensorDescriptor,
2211 gradients: &TensorDescriptor,
2212 labels: &[i32],
2213 label_lengths: &[i32],
2214 input_lengths: &[i32],
2215 algo: i32,
2216 desc: &CtcLossDescriptor,
2217) -> Result<usize> {
2218 let cu = cudnn()?;
2219 let q = cu.cudnn_get_ctc_loss_workspace_size()?;
2220 let mut size = 0usize;
2221 check(unsafe {
2222 q(
2223 handle.handle,
2224 probs.desc,
2225 gradients.desc,
2226 labels.as_ptr(),
2227 label_lengths.as_ptr(),
2228 input_lengths.as_ptr(),
2229 algo,
2230 desc.desc,
2231 &mut size,
2232 )
2233 })?;
2234 Ok(size)
2235}
2236
2237#[allow(clippy::too_many_arguments)]
2239pub fn ctc_loss<T: DeviceRepr>(
2240 handle: &Handle,
2241 probs_desc: &TensorDescriptor,
2242 probs: &DeviceBuffer<T>,
2243 labels: &[i32],
2244 label_lengths: &[i32],
2245 input_lengths: &[i32],
2246 costs: &mut DeviceBuffer<T>,
2247 gradients_desc: &TensorDescriptor,
2248 gradients: &mut DeviceBuffer<T>,
2249 algo: i32,
2250 desc: &CtcLossDescriptor,
2251 workspace: &mut DeviceBuffer<u8>,
2252) -> Result<()> {
2253 let c = cudnn()?;
2254 let cu = c.cudnn_ctc_loss()?;
2255 check(unsafe {
2256 cu(
2257 handle.handle,
2258 probs_desc.desc,
2259 probs.as_raw().0 as *const core::ffi::c_void,
2260 labels.as_ptr(),
2261 label_lengths.as_ptr(),
2262 input_lengths.as_ptr(),
2263 costs.as_raw().0 as *mut core::ffi::c_void,
2264 gradients_desc.desc,
2265 gradients.as_raw().0 as *mut core::ffi::c_void,
2266 algo,
2267 desc.desc,
2268 workspace.as_raw().0 as *mut core::ffi::c_void,
2269 workspace.byte_size(),
2270 )
2271 })
2272}
2273
2274use baracuda_cudnn_sys::cudnnSpatialTransformerDescriptor_t;
2277
2278pub struct SpatialTransformerDescriptor {
2280 desc: cudnnSpatialTransformerDescriptor_t,
2281}
2282
2283unsafe impl Send for SpatialTransformerDescriptor {}
2284
2285impl core::fmt::Debug for SpatialTransformerDescriptor {
2286 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2287 f.debug_struct("SpatialTransformerDescriptor")
2288 .field("desc", &self.desc)
2289 .finish_non_exhaustive()
2290 }
2291}
2292
2293impl SpatialTransformerDescriptor {
2294 pub fn new(sampler_type: i32, dtype: DType, dims: &[i32]) -> Result<Self> {
2296 let cu = cudnn()?;
2297 let create = cu.cudnn_create_spatial_transformer_descriptor()?;
2298 let set = cu.cudnn_set_spatial_transformer_nd_descriptor()?;
2299 let mut desc: cudnnSpatialTransformerDescriptor_t = core::ptr::null_mut();
2300 check(unsafe { create(&mut desc) })?;
2301 let this = Self { desc };
2302 check(unsafe {
2303 set(
2304 this.desc,
2305 sampler_type,
2306 dtype.raw(),
2307 dims.len() as core::ffi::c_int,
2308 dims.as_ptr(),
2309 )
2310 })?;
2311 Ok(this)
2312 }
2313
2314 #[inline]
2316 pub fn as_raw(&self) -> cudnnSpatialTransformerDescriptor_t {
2317 self.desc
2318 }
2319}
2320
2321impl Drop for SpatialTransformerDescriptor {
2322 fn drop(&mut self) {
2323 if let Ok(c) = cudnn() {
2324 if let Ok(cu) = c.cudnn_destroy_spatial_transformer_descriptor() {
2325 let _ = unsafe { cu(self.desc) };
2326 }
2327 }
2328 }
2329}
2330
2331pub fn spatial_tf_grid_generator<T: DeviceRepr>(
2333 handle: &Handle,
2334 st: &SpatialTransformerDescriptor,
2335 theta: &DeviceBuffer<T>,
2336 grid: &mut DeviceBuffer<T>,
2337) -> Result<()> {
2338 let c = cudnn()?;
2339 let cu = c.cudnn_spatial_tf_grid_generator_forward()?;
2340 check(unsafe {
2341 cu(
2342 handle.handle,
2343 st.desc,
2344 theta.as_raw().0 as *const core::ffi::c_void,
2345 grid.as_raw().0 as *mut core::ffi::c_void,
2346 )
2347 })
2348}
2349
2350#[allow(clippy::too_many_arguments)]
2352pub fn spatial_tf_sampler<T: DeviceRepr>(
2353 handle: &Handle,
2354 st: &SpatialTransformerDescriptor,
2355 alpha: f32,
2356 x_desc: &TensorDescriptor,
2357 x: &DeviceBuffer<T>,
2358 grid: &DeviceBuffer<T>,
2359 beta: f32,
2360 y_desc: &TensorDescriptor,
2361 y: &mut DeviceBuffer<T>,
2362) -> Result<()> {
2363 let c = cudnn()?;
2364 let cu = c.cudnn_spatial_tf_sampler_forward()?;
2365 check(unsafe {
2366 cu(
2367 handle.handle,
2368 st.desc,
2369 &alpha as *const f32 as *const core::ffi::c_void,
2370 x_desc.desc,
2371 x.as_raw().0 as *const core::ffi::c_void,
2372 grid.as_raw().0 as *const core::ffi::c_void,
2373 &beta as *const f32 as *const core::ffi::c_void,
2374 y_desc.desc,
2375 y.as_raw().0 as *mut core::ffi::c_void,
2376 )
2377 })
2378}
2379
2380#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
2387pub enum MathType {
2388 #[default]
2390 Default,
2391 TensorOp,
2393 TensorOpAllowConversion,
2395 FmaOnly,
2397}
2398
2399impl MathType {
2400 pub(crate) fn raw(self) -> cudnnMathType_t {
2401 match self {
2402 MathType::Default => cudnnMathType_t::DefaultMath,
2403 MathType::TensorOp => cudnnMathType_t::TensorOpMath,
2404 MathType::TensorOpAllowConversion => cudnnMathType_t::TensorOpMathAllowConversion,
2405 MathType::FmaOnly => cudnnMathType_t::FmaMath,
2406 }
2407 }
2408 pub(crate) fn from_raw(raw: cudnnMathType_t) -> Self {
2409 match raw {
2410 cudnnMathType_t::DefaultMath => MathType::Default,
2411 cudnnMathType_t::TensorOpMath => MathType::TensorOp,
2412 cudnnMathType_t::TensorOpMathAllowConversion => MathType::TensorOpAllowConversion,
2413 cudnnMathType_t::FmaMath => MathType::FmaOnly,
2414 }
2415 }
2416}
2417
2418#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
2420pub enum ReorderType {
2421 #[default]
2423 Default,
2424 None,
2426}
2427
2428impl ReorderType {
2429 pub(crate) fn raw(self) -> cudnnReorderType_t {
2430 match self {
2431 ReorderType::Default => cudnnReorderType_t::DefaultReorder,
2432 ReorderType::None => cudnnReorderType_t::NoReorder,
2433 }
2434 }
2435 pub(crate) fn from_raw(raw: cudnnReorderType_t) -> Self {
2436 match raw {
2437 cudnnReorderType_t::DefaultReorder => ReorderType::Default,
2438 cudnnReorderType_t::NoReorder => ReorderType::None,
2439 }
2440 }
2441}
2442
2443#[allow(clippy::too_many_arguments)]
2449pub unsafe fn reorder_filter_and_bias(
2450 handle: &Handle,
2451 filter_desc: &FilterDescriptor,
2452 reorder: ReorderType,
2453 filter_data: *const core::ffi::c_void,
2454 reordered_filter: *mut core::ffi::c_void,
2455 reorder_bias: bool,
2456 bias_data: *const core::ffi::c_void,
2457 reordered_bias: *mut core::ffi::c_void,
2458) -> Result<()> { unsafe {
2459 let c = cudnn()?;
2460 let f = c.cudnn_reorder_filter_and_bias()?;
2461 check(f(
2462 handle.handle, filter_desc.desc, reorder.raw(),
2463 filter_data, reordered_filter,
2464 reorder_bias as core::ffi::c_int, bias_data, reordered_bias,
2465 ))
2466}}
2467
2468#[allow(clippy::too_many_arguments)]
2472pub fn convolution_bias_activation_forward<T: DeviceRepr>(
2473 handle: &Handle,
2474 alpha1: f32,
2475 x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
2476 w_desc: &FilterDescriptor, w: &DeviceBuffer<T>,
2477 conv: &ConvolutionDescriptor,
2478 algo: FwdAlgo,
2479 workspace: &mut DeviceBuffer<u8>,
2480 alpha2: f32,
2481 z_desc: &TensorDescriptor, z: &DeviceBuffer<T>,
2482 bias_desc: &TensorDescriptor, bias: &DeviceBuffer<T>,
2483 activation: &ActivationDescriptor,
2484 y_desc: &TensorDescriptor, y: &mut DeviceBuffer<T>,
2485) -> Result<()> {
2486 let c = cudnn()?;
2487 let cu = c.cudnn_convolution_bias_activation_forward()?;
2488 check(unsafe {
2489 cu(
2490 handle.handle,
2491 &alpha1 as *const f32 as *const core::ffi::c_void,
2492 x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
2493 w_desc.desc, w.as_raw().0 as *const core::ffi::c_void,
2494 conv.desc, algo.raw(),
2495 workspace.as_raw().0 as *mut core::ffi::c_void, workspace.byte_size(),
2496 &alpha2 as *const f32 as *const core::ffi::c_void,
2497 z_desc.desc, z.as_raw().0 as *const core::ffi::c_void,
2498 bias_desc.desc, bias.as_raw().0 as *const core::ffi::c_void,
2499 activation.desc,
2500 y_desc.desc, y.as_raw().0 as *mut core::ffi::c_void,
2501 )
2502 })
2503}
2504
2505#[allow(clippy::too_many_arguments)]
2507pub fn activation_backward<T: DeviceRepr>(
2508 handle: &Handle,
2509 activation: &ActivationDescriptor,
2510 alpha: f32,
2511 y_desc: &TensorDescriptor, y: &DeviceBuffer<T>,
2512 dy_desc: &TensorDescriptor, dy: &DeviceBuffer<T>,
2513 x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
2514 beta: f32,
2515 dx_desc: &TensorDescriptor, dx: &mut DeviceBuffer<T>,
2516) -> Result<()> {
2517 let c = cudnn()?;
2518 let cu = c.cudnn_activation_backward()?;
2519 check(unsafe {
2520 cu(
2521 handle.handle, activation.desc,
2522 &alpha as *const f32 as *const core::ffi::c_void,
2523 y_desc.desc, y.as_raw().0 as *const core::ffi::c_void,
2524 dy_desc.desc, dy.as_raw().0 as *const core::ffi::c_void,
2525 x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
2526 &beta as *const f32 as *const core::ffi::c_void,
2527 dx_desc.desc, dx.as_raw().0 as *mut core::ffi::c_void,
2528 )
2529 })
2530}
2531
2532#[allow(clippy::too_many_arguments)]
2534pub fn lrn_cross_channel_backward<T: DeviceRepr>(
2535 handle: &Handle, lrn: &LrnDescriptor, mode: i32,
2536 alpha: f32,
2537 y_desc: &TensorDescriptor, y: &DeviceBuffer<T>,
2538 dy_desc: &TensorDescriptor, dy: &DeviceBuffer<T>,
2539 x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
2540 beta: f32,
2541 dx_desc: &TensorDescriptor, dx: &mut DeviceBuffer<T>,
2542) -> Result<()> {
2543 let c = cudnn()?;
2544 let cu = c.cudnn_lrn_cross_channel_backward()?;
2545 check(unsafe {
2546 cu(
2547 handle.handle, lrn.desc, mode,
2548 &alpha as *const f32 as *const core::ffi::c_void,
2549 y_desc.desc, y.as_raw().0 as *const core::ffi::c_void,
2550 dy_desc.desc, dy.as_raw().0 as *const core::ffi::c_void,
2551 x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
2552 &beta as *const f32 as *const core::ffi::c_void,
2553 dx_desc.desc, dx.as_raw().0 as *mut core::ffi::c_void,
2554 )
2555 })
2556}
2557
2558pub fn reduction_indices_size(
2560 handle: &Handle,
2561 reducer: &ReduceTensorDescriptor,
2562 a: &TensorDescriptor,
2563 c: &TensorDescriptor,
2564) -> Result<usize> {
2565 let cu = cudnn()?;
2566 let q = cu.cudnn_get_reduction_indices_size()?;
2567 let mut size = 0usize;
2568 check(unsafe { q(handle.handle, reducer.desc, a.desc, c.desc, &mut size) })?;
2569 Ok(size)
2570}
2571
2572impl ActivationDescriptor {
2573 pub fn set_swish_beta(&self, beta: f64) -> Result<()> {
2575 let c = cudnn()?;
2576 let f = c.cudnn_set_activation_descriptor_swish_beta()?;
2577 check(unsafe { f(self.desc, beta) })
2578 }
2579 pub fn swish_beta(&self) -> Result<f64> {
2581 let c = cudnn()?;
2582 let f = c.cudnn_get_activation_descriptor_swish_beta()?;
2583 let mut b: f64 = 0.0;
2584 check(unsafe { f(self.desc, &mut b) })?;
2585 Ok(b)
2586 }
2587}
2588
2589pub use baracuda_cudnn_sys::cudnnConvolutionFwdAlgoPerf_t as FwdAlgoPerf;
2595pub use baracuda_cudnn_sys::cudnnConvolutionBwdDataAlgoPerf_t as BwdDataAlgoPerf;
2597pub use baracuda_cudnn_sys::cudnnConvolutionBwdFilterAlgoPerf_t as BwdFilterAlgoPerf;
2599
2600pub fn get_convolution_forward_algorithm(
2602 handle: &Handle,
2603 src: &TensorDescriptor, filter: &FilterDescriptor,
2604 conv: &ConvolutionDescriptor, dst: &TensorDescriptor,
2605 requested: i32,
2606) -> Result<Vec<FwdAlgoPerf>> {
2607 let cu = cudnn()?;
2608 let f = cu.cudnn_get_convolution_forward_algorithm_v7()?;
2609 let mut returned: core::ffi::c_int = 0;
2610 let mut buf: Vec<FwdAlgoPerf> = Vec::with_capacity(requested as usize);
2611 let raw = unsafe {
2612 f(handle.handle, src.desc, filter.desc, conv.desc, dst.desc,
2613 requested, &mut returned, buf.as_mut_ptr())
2614 };
2615 check(raw)?;
2616 unsafe { buf.set_len(returned as usize); }
2617 Ok(buf)
2618}
2619
2620pub fn find_convolution_forward_algorithm(
2622 handle: &Handle,
2623 src: &TensorDescriptor, filter: &FilterDescriptor,
2624 conv: &ConvolutionDescriptor, dst: &TensorDescriptor,
2625 requested: i32,
2626) -> Result<Vec<FwdAlgoPerf>> {
2627 let cu = cudnn()?;
2628 let f = cu.cudnn_find_convolution_forward_algorithm()?;
2629 let mut returned: core::ffi::c_int = 0;
2630 let mut buf: Vec<FwdAlgoPerf> = Vec::with_capacity(requested as usize);
2631 let raw = unsafe {
2632 f(handle.handle, src.desc, filter.desc, conv.desc, dst.desc,
2633 requested, &mut returned, buf.as_mut_ptr())
2634 };
2635 check(raw)?;
2636 unsafe { buf.set_len(returned as usize); }
2637 Ok(buf)
2638}
2639
2640pub fn get_convolution_backward_data_algorithm(
2642 handle: &Handle,
2643 filter: &FilterDescriptor, diff: &TensorDescriptor,
2644 conv: &ConvolutionDescriptor, grad: &TensorDescriptor,
2645 requested: i32,
2646) -> Result<Vec<BwdDataAlgoPerf>> {
2647 let cu = cudnn()?;
2648 let f = cu.cudnn_get_convolution_backward_data_algorithm_v7()?;
2649 let mut returned: core::ffi::c_int = 0;
2650 let mut buf: Vec<BwdDataAlgoPerf> = Vec::with_capacity(requested as usize);
2651 let raw = unsafe {
2652 f(handle.handle, filter.desc, diff.desc, conv.desc, grad.desc,
2653 requested, &mut returned, buf.as_mut_ptr())
2654 };
2655 check(raw)?;
2656 unsafe { buf.set_len(returned as usize); }
2657 Ok(buf)
2658}
2659
2660pub fn get_convolution_backward_filter_algorithm(
2662 handle: &Handle,
2663 src: &TensorDescriptor, diff: &TensorDescriptor,
2664 conv: &ConvolutionDescriptor, grad: &FilterDescriptor,
2665 requested: i32,
2666) -> Result<Vec<BwdFilterAlgoPerf>> {
2667 let cu = cudnn()?;
2668 let f = cu.cudnn_get_convolution_backward_filter_algorithm_v7()?;
2669 let mut returned: core::ffi::c_int = 0;
2670 let mut buf: Vec<BwdFilterAlgoPerf> = Vec::with_capacity(requested as usize);
2671 let raw = unsafe {
2672 f(handle.handle, src.desc, diff.desc, conv.desc, grad.desc,
2673 requested, &mut returned, buf.as_mut_ptr())
2674 };
2675 check(raw)?;
2676 unsafe { buf.set_len(returned as usize); }
2677 Ok(buf)
2678}
2679
2680#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
2686pub enum NormMode {
2687 PerActivation,
2689 #[default]
2691 PerChannel,
2692}
2693impl NormMode {
2694 fn raw(self) -> cudnnNormMode_t {
2695 match self {
2696 NormMode::PerActivation => cudnnNormMode_t::PerActivation,
2697 NormMode::PerChannel => cudnnNormMode_t::PerChannel,
2698 }
2699 }
2700}
2701
2702#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
2704pub enum NormAlgo {
2705 #[default]
2707 Standard,
2708 Persist,
2710}
2711impl NormAlgo {
2712 fn raw(self) -> cudnnNormAlgo_t {
2713 match self {
2714 NormAlgo::Standard => cudnnNormAlgo_t::Standard,
2715 NormAlgo::Persist => cudnnNormAlgo_t::Persist,
2716 }
2717 }
2718}
2719
2720#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
2722pub enum NormOp {
2723 #[default]
2725 Norm,
2726 NormActivation,
2728 NormAddActivation,
2730}
2731impl NormOp {
2732 fn raw(self) -> cudnnNormOps_t {
2733 match self {
2734 NormOp::Norm => cudnnNormOps_t::Norm,
2735 NormOp::NormActivation => cudnnNormOps_t::NormActivation,
2736 NormOp::NormAddActivation => cudnnNormOps_t::NormAddActivation,
2737 }
2738 }
2739}
2740
2741#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
2743pub enum BnOp {
2744 #[default]
2746 Bn,
2747 BnActivation,
2749 BnAddActivation,
2751}
2752impl BnOp {
2753 fn raw(self) -> cudnnBatchNormOps_t {
2754 match self {
2755 BnOp::Bn => cudnnBatchNormOps_t::Bn,
2756 BnOp::BnActivation => cudnnBatchNormOps_t::BnActivation,
2757 BnOp::BnAddActivation => cudnnBatchNormOps_t::BnAddActivation,
2758 }
2759 }
2760}
2761
2762#[allow(clippy::too_many_arguments)]
2764pub fn batch_normalization_forward_training_ex_workspace_size(
2765 handle: &Handle,
2766 mode: BatchNormMode, bn_ops: BnOp,
2767 x: &TensorDescriptor, z: &TensorDescriptor, y: &TensorDescriptor,
2768 bn_smbv: &TensorDescriptor, activation: &ActivationDescriptor,
2769) -> Result<usize> {
2770 let cu = cudnn()?;
2771 let f = cu.cudnn_get_batch_normalization_forward_training_ex_workspace_size()?;
2772 let mut size = 0usize;
2773 check(unsafe {
2774 f(handle.handle, mode.raw(), bn_ops.raw(),
2775 x.desc, z.desc, y.desc, bn_smbv.desc, activation.desc, &mut size)
2776 })?;
2777 Ok(size)
2778}
2779
2780#[allow(clippy::too_many_arguments)]
2782pub fn batch_normalization_backward_ex_workspace_size(
2783 handle: &Handle,
2784 mode: BatchNormMode, bn_ops: BnOp,
2785 x: &TensorDescriptor, y: &TensorDescriptor, dy: &TensorDescriptor,
2786 dz: &TensorDescriptor, dx: &TensorDescriptor,
2787 d_bn_scale_bias: &TensorDescriptor, activation: &ActivationDescriptor,
2788) -> Result<usize> {
2789 let cu = cudnn()?;
2790 let f = cu.cudnn_get_batch_normalization_backward_ex_workspace_size()?;
2791 let mut size = 0usize;
2792 check(unsafe {
2793 f(handle.handle, mode.raw(), bn_ops.raw(),
2794 x.desc, y.desc, dy.desc, dz.desc, dx.desc,
2795 d_bn_scale_bias.desc, activation.desc, &mut size)
2796 })?;
2797 Ok(size)
2798}
2799
2800pub fn batch_normalization_training_ex_reserve_space_size(
2802 handle: &Handle,
2803 mode: BatchNormMode, bn_ops: BnOp,
2804 activation: &ActivationDescriptor, x: &TensorDescriptor,
2805) -> Result<usize> {
2806 let cu = cudnn()?;
2807 let f = cu.cudnn_get_batch_normalization_training_ex_reserve_space_size()?;
2808 let mut size = 0usize;
2809 check(unsafe {
2810 f(handle.handle, mode.raw(), bn_ops.raw(), activation.desc, x.desc, &mut size)
2811 })?;
2812 Ok(size)
2813}
2814
2815pub struct RnnDescriptor {
2821 desc: baracuda_cudnn_sys::cudnnRNNDescriptor_t,
2822}
2823unsafe impl Send for RnnDescriptor {}
2824impl core::fmt::Debug for RnnDescriptor {
2825 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2826 f.debug_struct("RnnDescriptor").field("desc", &self.desc).finish_non_exhaustive()
2827 }
2828}
2829impl RnnDescriptor {
2830 pub fn new() -> Result<Self> {
2833 let c = cudnn()?;
2834 let create = c.cudnn_create_rnn_descriptor()?;
2835 let mut desc: baracuda_cudnn_sys::cudnnRNNDescriptor_t = core::ptr::null_mut();
2836 check(unsafe { create(&mut desc) })?;
2837 Ok(Self { desc })
2838 }
2839
2840 #[allow(clippy::too_many_arguments)]
2843 pub fn set_v8(
2844 &self,
2845 algo: i32, cell_mode: i32, bias_mode: i32,
2846 dir_mode: i32, input_mode: i32,
2847 data_type: DType, math_prec: DType, math_type: MathType,
2848 input_size: i32, hidden_size: i32, proj_size: i32, num_layers: i32,
2849 dropout: &DropoutDescriptor, aux_flags: u32,
2850 ) -> Result<()> {
2851 use baracuda_cudnn_sys::{cudnnDirectionMode_t, cudnnRNNAlgo_t, cudnnRNNInputMode_t, cudnnRNNMode_t};
2852 let c = cudnn()?;
2853 let f = c.cudnn_set_rnn_descriptor_v8()?;
2854 let algo = match algo {
2855 0 => cudnnRNNAlgo_t::Standard,
2856 1 => cudnnRNNAlgo_t::PersistStatic,
2857 2 => cudnnRNNAlgo_t::PersistDynamic,
2858 _ => cudnnRNNAlgo_t::PersistStaticSmallH,
2859 };
2860 let cell = match cell_mode {
2861 0 => cudnnRNNMode_t::ReluRnn,
2862 1 => cudnnRNNMode_t::TanhRnn,
2863 2 => cudnnRNNMode_t::Lstm,
2864 _ => cudnnRNNMode_t::Gru,
2865 };
2866 let dir = if dir_mode == 1 { cudnnDirectionMode_t::Bidirectional } else { cudnnDirectionMode_t::Unidirectional };
2867 let im = if input_mode == 1 { cudnnRNNInputMode_t::SkipInput } else { cudnnRNNInputMode_t::LinearInput };
2868 check(unsafe {
2869 f(self.desc, algo, cell, bias_mode, dir, im,
2870 data_type.raw(), math_prec.raw(), math_type.raw(),
2871 input_size, hidden_size, proj_size, num_layers,
2872 dropout.desc, aux_flags)
2873 })
2874 }
2875
2876 #[inline]
2878 pub fn as_raw(&self) -> baracuda_cudnn_sys::cudnnRNNDescriptor_t { self.desc }
2879}
2880impl Drop for RnnDescriptor {
2881 fn drop(&mut self) {
2882 if let Ok(c) = cudnn() {
2883 if let Ok(cu) = c.cudnn_destroy_rnn_descriptor() {
2884 let _ = unsafe { cu(self.desc) };
2885 }
2886 }
2887 }
2888}
2889
2890pub struct RnnDataDescriptor {
2892 desc: baracuda_cudnn_sys::cudnnRNNDataDescriptor_t,
2893}
2894unsafe impl Send for RnnDataDescriptor {}
2895impl core::fmt::Debug for RnnDataDescriptor {
2896 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2897 f.debug_struct("RnnDataDescriptor").field("desc", &self.desc).finish_non_exhaustive()
2898 }
2899}
2900impl RnnDataDescriptor {
2901 pub fn new() -> Result<Self> {
2904 let c = cudnn()?;
2905 let create = c.cudnn_create_rnn_data_descriptor()?;
2906 let mut desc: baracuda_cudnn_sys::cudnnRNNDataDescriptor_t = core::ptr::null_mut();
2907 check(unsafe { create(&mut desc) })?;
2908 Ok(Self { desc })
2909 }
2910 #[inline]
2912 pub fn as_raw(&self) -> baracuda_cudnn_sys::cudnnRNNDataDescriptor_t { self.desc }
2913}
2914impl Drop for RnnDataDescriptor {
2915 fn drop(&mut self) {
2916 if let Ok(c) = cudnn() {
2917 if let Ok(cu) = c.cudnn_destroy_rnn_data_descriptor() {
2918 let _ = unsafe { cu(self.desc) };
2919 }
2920 }
2921 }
2922}
2923
2924pub fn build_rnn_dynamic(handle: &Handle, rnn: &RnnDescriptor, mini_batch: i32) -> Result<()> {
2926 let c = cudnn()?;
2927 let f = c.cudnn_build_rnn_dynamic()?;
2928 check(unsafe { f(handle.handle, rnn.desc, mini_batch) })
2929}
2930
2931pub fn rnn_temp_space_sizes(
2934 handle: &Handle, rnn: &RnnDescriptor, fwd_mode: i32, x: &RnnDataDescriptor,
2935) -> Result<(usize, usize)> {
2936 let c = cudnn()?;
2937 let f = c.cudnn_get_rnn_temp_space_sizes()?;
2938 let (mut ws, mut rs) = (0usize, 0usize);
2939 check(unsafe { f(handle.handle, rnn.desc, fwd_mode, x.desc, &mut ws, &mut rs) })?;
2940 Ok((ws, rs))
2941}
2942
2943pub fn rnn_weight_space_size(handle: &Handle, rnn: &RnnDescriptor) -> Result<usize> {
2945 let c = cudnn()?;
2946 let f = c.cudnn_get_rnn_weight_space_size()?;
2947 let mut size = 0usize;
2948 check(unsafe { f(handle.handle, rnn.desc, &mut size) })?;
2949 Ok(size)
2950}
2951
2952pub struct AttnDescriptor {
2958 desc: cudnnAttnDescriptor_t,
2959}
2960unsafe impl Send for AttnDescriptor {}
2961impl core::fmt::Debug for AttnDescriptor {
2962 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2963 f.debug_struct("AttnDescriptor").field("desc", &self.desc).finish_non_exhaustive()
2964 }
2965}
2966impl AttnDescriptor {
2967 pub fn new() -> Result<Self> {
2971 let c = cudnn()?;
2972 let cu = c.cudnn_create_attn_descriptor()?;
2973 let mut desc: cudnnAttnDescriptor_t = core::ptr::null_mut();
2974 check(unsafe { cu(&mut desc) })?;
2975 Ok(Self { desc })
2976 }
2977
2978 #[allow(clippy::too_many_arguments)]
2981 pub fn set(
2982 &self,
2983 attn_mode: u32, n_heads: i32, sm_scaler: f64,
2984 data_type: DType, compute_prec: DType, math_type: MathType,
2985 attn_dropout: &DropoutDescriptor, post_dropout: &DropoutDescriptor,
2986 q_size: i32, k_size: i32, v_size: i32,
2987 q_proj_size: i32, k_proj_size: i32, v_proj_size: i32, o_proj_size: i32,
2988 qo_max_seq_length: i32, kv_max_seq_length: i32,
2989 max_batch_size: i32, max_beam_size: i32,
2990 ) -> Result<()> {
2991 let c = cudnn()?;
2992 let f = c.cudnn_set_attn_descriptor()?;
2993 check(unsafe {
2994 f(self.desc, attn_mode, n_heads, sm_scaler,
2995 data_type.raw(), compute_prec.raw(), math_type.raw(),
2996 attn_dropout.desc, post_dropout.desc,
2997 q_size, k_size, v_size,
2998 q_proj_size, k_proj_size, v_proj_size, o_proj_size,
2999 qo_max_seq_length, kv_max_seq_length,
3000 max_batch_size, max_beam_size)
3001 })
3002 }
3003
3004 #[inline]
3006 pub fn as_raw(&self) -> cudnnAttnDescriptor_t { self.desc }
3007}
3008impl Drop for AttnDescriptor {
3009 fn drop(&mut self) {
3010 if let Ok(c) = cudnn() {
3011 if let Ok(cu) = c.cudnn_destroy_attn_descriptor() {
3012 let _ = unsafe { cu(self.desc) };
3013 }
3014 }
3015 }
3016}
3017
3018pub fn multi_head_attn_buffers(
3020 handle: &Handle, attn: &AttnDescriptor,
3021) -> Result<(usize, usize, usize)> {
3022 let c = cudnn()?;
3023 let f = c.cudnn_get_multi_head_attn_buffers()?;
3024 let (mut w, mut ws, mut rs) = (0usize, 0usize, 0usize);
3025 check(unsafe { f(handle.handle, attn.desc, &mut w, &mut ws, &mut rs) })?;
3026 Ok((w, ws, rs))
3027}
3028
3029pub struct SeqDataDescriptor {
3031 desc: cudnnSeqDataDescriptor_t,
3032}
3033unsafe impl Send for SeqDataDescriptor {}
3034impl core::fmt::Debug for SeqDataDescriptor {
3035 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
3036 f.debug_struct("SeqDataDescriptor").field("desc", &self.desc).finish_non_exhaustive()
3037 }
3038}
3039impl SeqDataDescriptor {
3040 pub fn new() -> Result<Self> {
3043 let c = cudnn()?;
3044 let cu = c.cudnn_create_seq_data_descriptor()?;
3045 let mut desc: cudnnSeqDataDescriptor_t = core::ptr::null_mut();
3046 check(unsafe { cu(&mut desc) })?;
3047 Ok(Self { desc })
3048 }
3049
3050 #[allow(clippy::too_many_arguments)]
3053 pub unsafe fn set(
3054 &self,
3055 data_type: DType,
3056 dim_a: &[i32], axes: &[i32], seq_length_array: &[i32],
3057 padding_fill: *const core::ffi::c_void,
3058 ) -> Result<()> { unsafe {
3059 let c = cudnn()?;
3060 let f = c.cudnn_set_seq_data_descriptor()?;
3061 check(f(
3062 self.desc, data_type.raw(),
3063 dim_a.len() as core::ffi::c_int,
3064 dim_a.as_ptr(), axes.as_ptr(),
3065 seq_length_array.len(), seq_length_array.as_ptr(),
3066 padding_fill,
3067 ))
3068 }}
3069
3070 #[inline]
3072 pub fn as_raw(&self) -> cudnnSeqDataDescriptor_t { self.desc }
3073}
3074impl Drop for SeqDataDescriptor {
3075 fn drop(&mut self) {
3076 if let Ok(c) = cudnn() {
3077 if let Ok(cu) = c.cudnn_destroy_seq_data_descriptor() {
3078 let _ = unsafe { cu(self.desc) };
3079 }
3080 }
3081 }
3082}
3083
3084pub use baracuda_cudnn_sys::{cudnnMathType_t as RawMathType, cudnnReorderType_t as RawReorderType};
3086
3087impl TensorDescriptor {
3092 #[allow(clippy::too_many_arguments)]
3095 pub fn new_4d_ex(
3096 dtype: DType,
3097 n: i32, c: i32, h: i32, w: i32,
3098 n_stride: i32, c_stride: i32, h_stride: i32, w_stride: i32,
3099 ) -> Result<Self> {
3100 let cu = cudnn()?;
3101 let create = cu.cudnn_create_tensor_descriptor()?;
3102 let set = cu.cudnn_set_tensor_4d_descriptor_ex()?;
3103 let mut desc: cudnnTensorDescriptor_t = core::ptr::null_mut();
3104 check(unsafe { create(&mut desc) })?;
3105 let this = Self { desc };
3106 check(unsafe {
3107 set(this.desc, dtype.raw(), n, c, h, w,
3108 n_stride, c_stride, h_stride, w_stride)
3109 })?;
3110 Ok(this)
3111 }
3112
3113 #[allow(clippy::type_complexity)]
3116 pub fn get_4d(&self) -> Result<(DType, i32, i32, i32, i32, i32, i32, i32, i32)> {
3117 let cu = cudnn()?;
3118 let f = cu.cudnn_get_tensor_4d_descriptor()?;
3119 let mut dt = cudnnDataType_t::Float;
3120 let (mut n, mut c, mut h, mut w) = (0i32, 0i32, 0i32, 0i32);
3121 let (mut ns, mut cs, mut hs, mut ws) = (0i32, 0i32, 0i32, 0i32);
3122 check(unsafe {
3123 f(self.desc, &mut dt, &mut n, &mut c, &mut h, &mut w,
3124 &mut ns, &mut cs, &mut hs, &mut ws)
3125 })?;
3126 let dtype = match dt {
3127 cudnnDataType_t::Float => DType::F32,
3128 cudnnDataType_t::Double => DType::F64,
3129 cudnnDataType_t::Half => DType::F16,
3130 cudnnDataType_t::BFloat16 => DType::BF16,
3131 cudnnDataType_t::Int8 => DType::I8,
3132 cudnnDataType_t::Int32 => DType::I32,
3133 _ => DType::F32,
3134 };
3135 Ok((dtype, n, c, h, w, ns, cs, hs, ws))
3136 }
3137}
3138
3139impl FilterDescriptor {
3140 pub fn get_4d(&self) -> Result<(DType, TensorFormat, i32, i32, i32, i32)> {
3142 let cu = cudnn()?;
3143 let f = cu.cudnn_get_filter_4d_descriptor()?;
3144 let mut dt = cudnnDataType_t::Float;
3145 let mut fmt = cudnnTensorFormat_t::Nchw;
3146 let (mut k, mut c, mut h, mut w) = (0i32, 0i32, 0i32, 0i32);
3147 check(unsafe {
3148 f(self.desc, &mut dt, &mut fmt, &mut k, &mut c, &mut h, &mut w)
3149 })?;
3150 let dtype = match dt {
3151 cudnnDataType_t::Float => DType::F32,
3152 cudnnDataType_t::Double => DType::F64,
3153 cudnnDataType_t::Half => DType::F16,
3154 cudnnDataType_t::BFloat16 => DType::BF16,
3155 cudnnDataType_t::Int8 => DType::I8,
3156 cudnnDataType_t::Int32 => DType::I32,
3157 _ => DType::F32,
3158 };
3159 let format = match fmt {
3160 cudnnTensorFormat_t::Nchw => TensorFormat::Nchw,
3161 cudnnTensorFormat_t::Nhwc => TensorFormat::Nhwc,
3162 _ => TensorFormat::Nchw,
3163 };
3164 Ok((dtype, format, k, c, h, w))
3165 }
3166}
3167
3168impl DropoutDescriptor {
3169 pub fn get(&self, handle: &Handle) -> Result<(f32, *mut core::ffi::c_void, u64)> {
3172 let cu = cudnn()?;
3173 let f = cu.cudnn_get_dropout_descriptor()?;
3174 let mut dropout: f32 = 0.0;
3175 let mut states: *mut core::ffi::c_void = core::ptr::null_mut();
3176 let mut seed: u64 = 0;
3177 check(unsafe { f(self.desc, handle.handle, &mut dropout, &mut states, &mut seed) })?;
3178 Ok((dropout, states, seed))
3179 }
3180
3181 pub unsafe fn restore(
3188 &self, handle: &Handle, dropout: f32,
3189 states: *mut core::ffi::c_void, state_size: usize, seed: u64,
3190 ) -> Result<()> { unsafe {
3191 let cu = cudnn()?;
3192 let f = cu.cudnn_restore_dropout_descriptor()?;
3193 check(f(self.desc, handle.handle, dropout, states, state_size, seed))
3194 }}
3195}
3196
/// Batch-normalization forward pass in training mode, with optional fused
/// ops selected by `bn_ops`, via `cudnnBatchNormalizationForwardTrainingEx`.
///
/// `alpha` and `beta` are host scaling factors passed to cuDNN by pointer.
/// `z` and `activation` are always forwarded; whether cuDNN uses them depends
/// on `bn_ops`. `running_mean`/`running_var` and `saved_mean`/`saved_inv_var`
/// are passed as writable device pointers. The full byte sizes of
/// `workspace` and `reserve` are passed as their sizes. Size `reserve` with
/// `batch_normalization_training_ex_reserve_space_size`.
///
/// NOTE(review): `alpha`/`beta` are always passed as `f32`. cuDNN expects
/// `double` scaling factors for `Double` tensors, so `T = f64` callers may see
/// garbage scaling. Confirm against the rest of this crate.
///
/// # Errors
/// Fails if the entry point cannot be resolved or cuDNN reports a
/// non-success status.
#[allow(clippy::too_many_arguments)]
pub fn batch_normalization_forward_training_ex<T: DeviceRepr>(
    handle: &Handle,
    mode: BatchNormMode, bn_ops: BnOp,
    alpha: f32, beta: f32,
    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
    z_desc: &TensorDescriptor, z: &DeviceBuffer<T>,
    y_desc: &TensorDescriptor, y: &mut DeviceBuffer<T>,
    bn_smbv_desc: &TensorDescriptor,
    bn_scale: &DeviceBuffer<T>, bn_bias: &DeviceBuffer<T>,
    exponential_avg_factor: f64,
    running_mean: &mut DeviceBuffer<T>, running_var: &mut DeviceBuffer<T>,
    epsilon: f64,
    saved_mean: &mut DeviceBuffer<T>, saved_inv_var: &mut DeviceBuffer<T>,
    activation: &ActivationDescriptor,
    workspace: &mut DeviceBuffer<u8>, reserve: &mut DeviceBuffer<u8>,
) -> Result<()> {
    let c = cudnn()?;
    let cu = c.cudnn_batch_normalization_forward_training_ex()?;
    // SAFETY: descriptors and buffers are borrowed for the duration of the
    // call, and `alpha`/`beta` live on this stack frame until `cu` returns.
    // `as_raw().0` is presumably the raw device address, cast to the pointer
    // type cuDNN expects.
    // NOTE(review): buffer sizes are not checked against the descriptors here.
    check(unsafe {
        cu(
            handle.handle, mode.raw(), bn_ops.raw(),
            &alpha as *const f32 as *const core::ffi::c_void,
            &beta as *const f32 as *const core::ffi::c_void,
            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
            z_desc.desc, z.as_raw().0 as *const core::ffi::c_void,
            y_desc.desc, y.as_raw().0 as *mut core::ffi::c_void,
            bn_smbv_desc.desc,
            bn_scale.as_raw().0 as *const core::ffi::c_void,
            bn_bias.as_raw().0 as *const core::ffi::c_void,
            exponential_avg_factor,
            running_mean.as_raw().0 as *mut core::ffi::c_void,
            running_var.as_raw().0 as *mut core::ffi::c_void,
            epsilon,
            saved_mean.as_raw().0 as *mut core::ffi::c_void,
            saved_inv_var.as_raw().0 as *mut core::ffi::c_void,
            activation.desc,
            // Workspace and reserve are passed as (pointer, size-in-bytes).
            workspace.as_raw().0 as *mut core::ffi::c_void, workspace.byte_size(),
            reserve.as_raw().0 as *mut core::ffi::c_void, reserve.byte_size(),
        )
    })
}
3245
/// Batch-normalization backward pass, with optional fused ops selected by
/// `bn_ops`, via `cudnnBatchNormalizationBackwardEx`.
///
/// `alpha_data`/`beta_data` and `alpha_param`/`beta_param` are host scaling
/// factors passed by pointer. Per the cuDNN docs they scale the data gradient
/// and the parameter gradients respectively. `dz`, `dx`, `d_bn_scale` and
/// `d_bn_bias` are passed as writable device pointers. `saved_mean` and
/// `saved_inv_var` are read-only inputs, presumably produced by the matching
/// forward-training call. Size `workspace` with
/// `batch_normalization_backward_ex_workspace_size`.
///
/// NOTE(review): the scaling factors are always `f32`. cuDNN expects `double`
/// for `Double` tensors. Confirm `T = f64` callers.
///
/// # Errors
/// Fails if the entry point cannot be resolved or cuDNN reports a
/// non-success status.
#[allow(clippy::too_many_arguments)]
pub fn batch_normalization_backward_ex<T: DeviceRepr>(
    handle: &Handle,
    mode: BatchNormMode, bn_ops: BnOp,
    alpha_data: f32, beta_data: f32,
    alpha_param: f32, beta_param: f32,
    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
    y_desc: &TensorDescriptor, y: &DeviceBuffer<T>,
    dy_desc: &TensorDescriptor, dy: &DeviceBuffer<T>,
    dz_desc: &TensorDescriptor, dz: &mut DeviceBuffer<T>,
    dx_desc: &TensorDescriptor, dx: &mut DeviceBuffer<T>,
    d_bn_scale_bias_desc: &TensorDescriptor,
    bn_scale: &DeviceBuffer<T>, bn_bias: &DeviceBuffer<T>,
    d_bn_scale: &mut DeviceBuffer<T>, d_bn_bias: &mut DeviceBuffer<T>,
    epsilon: f64,
    saved_mean: &DeviceBuffer<T>, saved_inv_var: &DeviceBuffer<T>,
    activation: &ActivationDescriptor,
    workspace: &mut DeviceBuffer<u8>, reserve: &mut DeviceBuffer<u8>,
) -> Result<()> {
    let c = cudnn()?;
    let cu = c.cudnn_batch_normalization_backward_ex()?;
    // SAFETY: descriptors and buffers are borrowed for the call; the four
    // scaling factors live on this stack frame until `cu` returns.
    // NOTE(review): buffer sizes are not checked against the descriptors here.
    check(unsafe {
        cu(
            handle.handle, mode.raw(), bn_ops.raw(),
            &alpha_data as *const f32 as *const core::ffi::c_void,
            &beta_data as *const f32 as *const core::ffi::c_void,
            &alpha_param as *const f32 as *const core::ffi::c_void,
            &beta_param as *const f32 as *const core::ffi::c_void,
            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
            y_desc.desc, y.as_raw().0 as *const core::ffi::c_void,
            dy_desc.desc, dy.as_raw().0 as *const core::ffi::c_void,
            dz_desc.desc, dz.as_raw().0 as *mut core::ffi::c_void,
            dx_desc.desc, dx.as_raw().0 as *mut core::ffi::c_void,
            d_bn_scale_bias_desc.desc,
            bn_scale.as_raw().0 as *const core::ffi::c_void,
            bn_bias.as_raw().0 as *const core::ffi::c_void,
            d_bn_scale.as_raw().0 as *mut core::ffi::c_void,
            d_bn_bias.as_raw().0 as *mut core::ffi::c_void,
            epsilon,
            saved_mean.as_raw().0 as *const core::ffi::c_void,
            saved_inv_var.as_raw().0 as *const core::ffi::c_void,
            activation.desc,
            // Workspace and reserve are passed as (pointer, size-in-bytes).
            workspace.as_raw().0 as *mut core::ffi::c_void, workspace.byte_size(),
            reserve.as_raw().0 as *mut core::ffi::c_void, reserve.byte_size(),
        )
    })
}
3294
/// Generic normalization forward pass in inference mode, via
/// `cudnnNormalizationForwardInference`.
///
/// Uses the supplied `estimated_mean`/`estimated_var` statistics. `alpha`
/// and `beta` are host scaling factors passed by pointer. `z` and
/// `activation` are always forwarded; whether cuDNN uses them depends on
/// `ops`. `group_count` is forwarded as cuDNN's trailing `groupCnt`.
///
/// NOTE(review): `alpha`/`beta` are always `f32`. cuDNN expects `double` for
/// `Double` tensors. Confirm `T = f64` callers.
///
/// # Errors
/// Fails if the entry point cannot be resolved or cuDNN reports a
/// non-success status.
#[allow(clippy::too_many_arguments)]
pub fn normalization_forward_inference<T: DeviceRepr>(
    handle: &Handle,
    mode: NormMode, ops: NormOp, algo: NormAlgo,
    alpha: f32, beta: f32,
    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
    norm_scale_bias_desc: &TensorDescriptor,
    norm_scale: &DeviceBuffer<T>, norm_bias: &DeviceBuffer<T>,
    norm_mean_var_desc: &TensorDescriptor,
    estimated_mean: &DeviceBuffer<T>, estimated_var: &DeviceBuffer<T>,
    z_desc: &TensorDescriptor, z: &DeviceBuffer<T>,
    activation: &ActivationDescriptor,
    y_desc: &TensorDescriptor, y: &mut DeviceBuffer<T>,
    epsilon: f64, group_count: i32,
) -> Result<()> {
    let c = cudnn()?;
    let cu = c.cudnn_normalization_forward_inference()?;
    // SAFETY: descriptors and buffers are borrowed for the call; `alpha`/`beta`
    // live on this stack frame until `cu` returns.
    check(unsafe {
        cu(
            handle.handle, mode.raw(), ops.raw(), algo.raw(),
            &alpha as *const f32 as *const core::ffi::c_void,
            &beta as *const f32 as *const core::ffi::c_void,
            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
            norm_scale_bias_desc.desc,
            norm_scale.as_raw().0 as *const core::ffi::c_void,
            norm_bias.as_raw().0 as *const core::ffi::c_void,
            norm_mean_var_desc.desc,
            estimated_mean.as_raw().0 as *const core::ffi::c_void,
            estimated_var.as_raw().0 as *const core::ffi::c_void,
            z_desc.desc, z.as_raw().0 as *const core::ffi::c_void,
            activation.desc,
            y_desc.desc, y.as_raw().0 as *mut core::ffi::c_void,
            epsilon, group_count,
        )
    })
}
3336
3337#[allow(clippy::too_many_arguments)]
3339pub fn normalization_forward_training_workspace_size(
3340 handle: &Handle,
3341 mode: NormMode, ops: NormOp, algo: NormAlgo,
3342 x_desc: &TensorDescriptor, z_desc: &TensorDescriptor,
3343 y_desc: &TensorDescriptor, norm_scale_bias_desc: &TensorDescriptor,
3344 activation: &ActivationDescriptor, norm_mean_var_desc: &TensorDescriptor,
3345 group_count: i32,
3346) -> Result<usize> {
3347 let c = cudnn()?;
3348 let f = c.cudnn_get_normalization_forward_training_workspace_size()?;
3349 let mut size = 0usize;
3350 check(unsafe {
3351 f(handle.handle, mode.raw(), ops.raw(), algo.raw(),
3352 x_desc.desc, z_desc.desc, y_desc.desc, norm_scale_bias_desc.desc,
3353 activation.desc, norm_mean_var_desc.desc, &mut size, group_count)
3354 })?;
3355 Ok(size)
3356}
3357
3358#[allow(clippy::too_many_arguments)]
3360pub fn normalization_backward_workspace_size(
3361 handle: &Handle,
3362 mode: NormMode, ops: NormOp, algo: NormAlgo,
3363 x_desc: &TensorDescriptor, y_desc: &TensorDescriptor,
3364 dy_desc: &TensorDescriptor, dz_desc: &TensorDescriptor,
3365 dx_desc: &TensorDescriptor, d_norm_scale_bias_desc: &TensorDescriptor,
3366 activation: &ActivationDescriptor, norm_mean_var_desc: &TensorDescriptor,
3367 group_count: i32,
3368) -> Result<usize> {
3369 let c = cudnn()?;
3370 let f = c.cudnn_get_normalization_backward_workspace_size()?;
3371 let mut size = 0usize;
3372 check(unsafe {
3373 f(handle.handle, mode.raw(), ops.raw(), algo.raw(),
3374 x_desc.desc, y_desc.desc, dy_desc.desc, dz_desc.desc,
3375 dx_desc.desc, d_norm_scale_bias_desc.desc,
3376 activation.desc, norm_mean_var_desc.desc, &mut size, group_count)
3377 })?;
3378 Ok(size)
3379}
3380
3381pub fn normalization_training_reserve_space_size(
3383 handle: &Handle,
3384 mode: NormMode, ops: NormOp, algo: NormAlgo,
3385 activation: &ActivationDescriptor, x_desc: &TensorDescriptor,
3386 group_count: i32,
3387) -> Result<usize> {
3388 let c = cudnn()?;
3389 let f = c.cudnn_get_normalization_training_reserve_space_size()?;
3390 let mut size = 0usize;
3391 check(unsafe {
3392 f(handle.handle, mode.raw(), ops.raw(), algo.raw(),
3393 activation.desc, x_desc.desc, &mut size, group_count)
3394 })?;
3395 Ok(size)
3396}
3397
/// Generic normalization forward pass in training mode, via
/// `cudnnNormalizationForwardTraining`.
///
/// `alpha`/`beta` are host scaling factors passed by pointer.
/// `running_mean`/`running_var` and `saved_mean`/`saved_inv_var` are passed
/// as writable device pointers; `exponential_avg_factor` is forwarded with
/// them. `z` and `activation` are always forwarded; whether cuDNN uses them
/// depends on `ops`. Size the buffers with
/// `normalization_forward_training_workspace_size` and
/// `normalization_training_reserve_space_size`. Their full byte sizes are
/// passed to cuDNN.
///
/// NOTE(review): `alpha`/`beta` are always `f32`. cuDNN expects `double` for
/// `Double` tensors. Confirm `T = f64` callers.
///
/// # Errors
/// Fails if the entry point cannot be resolved or cuDNN reports a
/// non-success status.
#[allow(clippy::too_many_arguments)]
pub fn normalization_forward_training<T: DeviceRepr>(
    handle: &Handle,
    mode: NormMode, ops: NormOp, algo: NormAlgo,
    alpha: f32, beta: f32,
    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
    norm_scale_bias_desc: &TensorDescriptor,
    norm_scale: &DeviceBuffer<T>, norm_bias: &DeviceBuffer<T>,
    exponential_avg_factor: f64,
    norm_mean_var_desc: &TensorDescriptor,
    running_mean: &mut DeviceBuffer<T>, running_var: &mut DeviceBuffer<T>,
    epsilon: f64,
    saved_mean: &mut DeviceBuffer<T>, saved_inv_var: &mut DeviceBuffer<T>,
    activation: &ActivationDescriptor,
    z_desc: &TensorDescriptor, z: &DeviceBuffer<T>,
    y_desc: &TensorDescriptor, y: &mut DeviceBuffer<T>,
    workspace: &mut DeviceBuffer<u8>, reserve: &mut DeviceBuffer<u8>,
    group_count: i32,
) -> Result<()> {
    let c = cudnn()?;
    let cu = c.cudnn_normalization_forward_training()?;
    // SAFETY: descriptors and buffers are borrowed for the call; `alpha`/`beta`
    // live on this stack frame until `cu` returns.
    // NOTE(review): buffer sizes are not checked against the descriptors here.
    check(unsafe {
        cu(
            handle.handle, mode.raw(), ops.raw(), algo.raw(),
            &alpha as *const f32 as *const core::ffi::c_void,
            &beta as *const f32 as *const core::ffi::c_void,
            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
            norm_scale_bias_desc.desc,
            norm_scale.as_raw().0 as *const core::ffi::c_void,
            norm_bias.as_raw().0 as *const core::ffi::c_void,
            exponential_avg_factor,
            norm_mean_var_desc.desc,
            running_mean.as_raw().0 as *mut core::ffi::c_void,
            running_var.as_raw().0 as *mut core::ffi::c_void,
            epsilon,
            saved_mean.as_raw().0 as *mut core::ffi::c_void,
            saved_inv_var.as_raw().0 as *mut core::ffi::c_void,
            activation.desc,
            z_desc.desc, z.as_raw().0 as *const core::ffi::c_void,
            y_desc.desc, y.as_raw().0 as *mut core::ffi::c_void,
            // Workspace and reserve are passed as (pointer, size-in-bytes).
            workspace.as_raw().0 as *mut core::ffi::c_void, workspace.byte_size(),
            reserve.as_raw().0 as *mut core::ffi::c_void, reserve.byte_size(),
            group_count,
        )
    })
}
3445
/// Generic normalization backward pass, via `cudnnNormalizationBackward`.
///
/// `alpha_data`/`beta_data` and `alpha_param`/`beta_param` are host scaling
/// factors passed by pointer. Per the cuDNN docs they scale the data gradient
/// and the parameter gradients respectively. `dz`, `dx`, `d_norm_scale` and
/// `d_norm_bias` are passed as writable device pointers. `saved_mean` and
/// `saved_inv_var` are read-only, presumably from the matching
/// forward-training call. Size `workspace` with
/// `normalization_backward_workspace_size`.
///
/// NOTE(review): the scaling factors are always `f32`. cuDNN expects `double`
/// for `Double` tensors. Confirm `T = f64` callers.
///
/// # Errors
/// Fails if the entry point cannot be resolved or cuDNN reports a
/// non-success status.
#[allow(clippy::too_many_arguments)]
pub fn normalization_backward<T: DeviceRepr>(
    handle: &Handle,
    mode: NormMode, ops: NormOp, algo: NormAlgo,
    alpha_data: f32, beta_data: f32,
    alpha_param: f32, beta_param: f32,
    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
    y_desc: &TensorDescriptor, y: &DeviceBuffer<T>,
    dy_desc: &TensorDescriptor, dy: &DeviceBuffer<T>,
    dz_desc: &TensorDescriptor, dz: &mut DeviceBuffer<T>,
    dx_desc: &TensorDescriptor, dx: &mut DeviceBuffer<T>,
    d_norm_scale_bias_desc: &TensorDescriptor,
    norm_scale: &DeviceBuffer<T>, norm_bias: &DeviceBuffer<T>,
    d_norm_scale: &mut DeviceBuffer<T>, d_norm_bias: &mut DeviceBuffer<T>,
    epsilon: f64,
    norm_mean_var_desc: &TensorDescriptor,
    saved_mean: &DeviceBuffer<T>, saved_inv_var: &DeviceBuffer<T>,
    activation: &ActivationDescriptor,
    workspace: &mut DeviceBuffer<u8>, reserve: &mut DeviceBuffer<u8>,
    group_count: i32,
) -> Result<()> {
    let c = cudnn()?;
    let cu = c.cudnn_normalization_backward()?;
    // SAFETY: descriptors and buffers are borrowed for the call; the four
    // scaling factors live on this stack frame until `cu` returns.
    // NOTE(review): buffer sizes are not checked against the descriptors here.
    check(unsafe {
        cu(
            handle.handle, mode.raw(), ops.raw(), algo.raw(),
            &alpha_data as *const f32 as *const core::ffi::c_void,
            &beta_data as *const f32 as *const core::ffi::c_void,
            &alpha_param as *const f32 as *const core::ffi::c_void,
            &beta_param as *const f32 as *const core::ffi::c_void,
            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
            y_desc.desc, y.as_raw().0 as *const core::ffi::c_void,
            dy_desc.desc, dy.as_raw().0 as *const core::ffi::c_void,
            dz_desc.desc, dz.as_raw().0 as *mut core::ffi::c_void,
            dx_desc.desc, dx.as_raw().0 as *mut core::ffi::c_void,
            d_norm_scale_bias_desc.desc,
            norm_scale.as_raw().0 as *const core::ffi::c_void,
            norm_bias.as_raw().0 as *const core::ffi::c_void,
            d_norm_scale.as_raw().0 as *mut core::ffi::c_void,
            d_norm_bias.as_raw().0 as *mut core::ffi::c_void,
            epsilon,
            norm_mean_var_desc.desc,
            saved_mean.as_raw().0 as *const core::ffi::c_void,
            saved_inv_var.as_raw().0 as *const core::ffi::c_void,
            activation.desc,
            // Workspace and reserve are passed as (pointer, size-in-bytes).
            workspace.as_raw().0 as *mut core::ffi::c_void, workspace.byte_size(),
            reserve.as_raw().0 as *mut core::ffi::c_void, reserve.byte_size(),
            group_count,
        )
    })
}
3498
3499#[allow(clippy::too_many_arguments)]
3514pub unsafe fn get_multi_head_attn_weights(
3515 handle: &Handle,
3516 attn: &AttnDescriptor,
3517 w_kind: i32,
3518 weight_size_in_bytes: usize,
3519 weights: *const core::ffi::c_void,
3520 w_desc: &TensorDescriptor,
3521) -> Result<*mut core::ffi::c_void> { unsafe {
3522 let c = cudnn()?;
3523 let f = c.cudnn_get_multi_head_attn_weights()?;
3524 let mut addr: *mut core::ffi::c_void = core::ptr::null_mut();
3525 check(f(
3526 handle.handle, attn.desc, w_kind, weight_size_in_bytes, weights,
3527 w_desc.desc, &mut addr,
3528 ))?;
3529 Ok(addr)
3530}}
3531
/// Multi-head attention forward pass, via `cudnnMultiHeadAttnForward`.
///
/// `curr_idx`, the window-index slices, and the sequence-length pointers are
/// forwarded unchanged. Per the cuDNN docs, a negative `curr_idx` selects
/// training mode. The `weights`, `work_space` and `reserve_space` buffers are
/// passed as (size-in-bytes, pointer) pairs using their full byte sizes.
///
/// # Safety
/// - `queries`, `residuals`, `keys`, `values` and `out` are raw pointers handed
///   straight to cuDNN. They must be device addresses whose layouts match
///   `q_desc`, `k_desc`, `v_desc` and `o_desc`. Per the cuDNN docs,
///   `residuals` may be null; confirm.
/// - `dev_seq_lengths_qo` / `dev_seq_lengths_kv` are forwarded unchanged.
///   cuDNN documents them as device arrays of per-sequence lengths.
/// - Only the pointers of `lo_win_idx` / `hi_win_idx` are passed; their
///   lengths are NOT checked here. NOTE(review): cuDNN reads a number of
///   entries that depends on `curr_idx` and the descriptor's max sequence
///   length. Confirm callers size these slices accordingly.
///
/// # Errors
/// Fails if the entry point cannot be resolved or cuDNN reports an error.
#[allow(clippy::too_many_arguments)]
pub unsafe fn multi_head_attn_forward(
    handle: &Handle,
    attn: &AttnDescriptor,
    curr_idx: i32,
    lo_win_idx: &[i32],
    hi_win_idx: &[i32],
    dev_seq_lengths_qo: *const i32,
    dev_seq_lengths_kv: *const i32,
    q_desc: &SeqDataDescriptor, queries: *const core::ffi::c_void,
    residuals: *const core::ffi::c_void,
    k_desc: &SeqDataDescriptor, keys: *const core::ffi::c_void,
    v_desc: &SeqDataDescriptor, values: *const core::ffi::c_void,
    o_desc: &SeqDataDescriptor, out: *mut core::ffi::c_void,
    weights: &DeviceBuffer<u8>,
    work_space: &mut DeviceBuffer<u8>,
    reserve_space: &mut DeviceBuffer<u8>,
) -> Result<()> { unsafe {
    let c = cudnn()?;
    let f = c.cudnn_multi_head_attn_forward()?;
    check(f(
        handle.handle, attn.desc,
        curr_idx, lo_win_idx.as_ptr(), hi_win_idx.as_ptr(),
        dev_seq_lengths_qo, dev_seq_lengths_kv,
        q_desc.desc, queries, residuals,
        k_desc.desc, keys,
        v_desc.desc, values,
        o_desc.desc, out,
        // Buffers are passed as (size-in-bytes, pointer), in cuDNN's order.
        weights.byte_size(), weights.as_raw().0 as *const core::ffi::c_void,
        work_space.byte_size(), work_space.as_raw().0 as *mut core::ffi::c_void,
        reserve_space.byte_size(), reserve_space.as_raw().0 as *mut core::ffi::c_void,
    ))
}}
3573
/// Multi-head attention backward pass for the data gradients, via
/// `cudnnMultiHeadAttnBackwardData`.
///
/// Writes `dqueries`, `dkeys` and `dvalues`. The `weights`, `work_space` and
/// `reserve_space` buffers are passed as (size-in-bytes, pointer) pairs using
/// their full byte sizes. `reserve_space` is presumably the one filled by a
/// preceding training-mode `multi_head_attn_forward`; confirm.
///
/// # Safety
/// - Every raw data pointer must be a device address whose layout matches
///   the descriptor that precedes it in the argument list.
/// - The sequence-length pointers are forwarded unchanged. cuDNN documents
///   them as device arrays.
/// - Only the pointers of `lo_win_idx` / `hi_win_idx` are passed; their
///   lengths are NOT checked here. NOTE(review): confirm callers size them
///   to the descriptor's max query sequence length.
///
/// # Errors
/// Fails if the entry point cannot be resolved or cuDNN reports an error.
#[allow(clippy::too_many_arguments)]
pub unsafe fn multi_head_attn_backward_data(
    handle: &Handle,
    attn: &AttnDescriptor,
    lo_win_idx: &[i32],
    hi_win_idx: &[i32],
    dev_seq_lengths_dqdo: *const i32,
    dev_seq_lengths_dkdv: *const i32,
    do_desc: &SeqDataDescriptor, dout: *const core::ffi::c_void,
    dq_desc: &SeqDataDescriptor, dqueries: *mut core::ffi::c_void,
    queries: *const core::ffi::c_void,
    dk_desc: &SeqDataDescriptor, dkeys: *mut core::ffi::c_void,
    keys: *const core::ffi::c_void,
    dv_desc: &SeqDataDescriptor, dvalues: *mut core::ffi::c_void,
    values: *const core::ffi::c_void,
    weights: &DeviceBuffer<u8>,
    work_space: &mut DeviceBuffer<u8>,
    reserve_space: &mut DeviceBuffer<u8>,
) -> Result<()> { unsafe {
    let c = cudnn()?;
    let f = c.cudnn_multi_head_attn_backward_data()?;
    check(f(
        handle.handle, attn.desc,
        lo_win_idx.as_ptr(), hi_win_idx.as_ptr(),
        dev_seq_lengths_dqdo, dev_seq_lengths_dkdv,
        do_desc.desc, dout,
        // Each gradient output is followed by the matching forward input.
        dq_desc.desc, dqueries, queries,
        dk_desc.desc, dkeys, keys,
        dv_desc.desc, dvalues, values,
        // Buffers are passed as (size-in-bytes, pointer), in cuDNN's order.
        weights.byte_size(), weights.as_raw().0 as *const core::ffi::c_void,
        work_space.byte_size(), work_space.as_raw().0 as *mut core::ffi::c_void,
        reserve_space.byte_size(), reserve_space.as_raw().0 as *mut core::ffi::c_void,
    ))
}}
3612
3613#[allow(clippy::too_many_arguments)]
3620pub unsafe fn multi_head_attn_backward_weights(
3621 handle: &Handle,
3622 attn: &AttnDescriptor,
3623 add_grad: bool,
3624 q_desc: &SeqDataDescriptor, queries: *const core::ffi::c_void,
3625 k_desc: &SeqDataDescriptor, keys: *const core::ffi::c_void,
3626 v_desc: &SeqDataDescriptor, values: *const core::ffi::c_void,
3627 do_desc: &SeqDataDescriptor, dout: *const core::ffi::c_void,
3628 weights: &DeviceBuffer<u8>,
3629 dweights: &mut DeviceBuffer<u8>,
3630 work_space: &mut DeviceBuffer<u8>,
3631 reserve_space: &mut DeviceBuffer<u8>,
3632) -> Result<()> { unsafe {
3633 let c = cudnn()?;
3634 let f = c.cudnn_multi_head_attn_backward_weights()?;
3635 check(f(
3636 handle.handle, attn.desc, add_grad as core::ffi::c_int,
3637 q_desc.desc, queries,
3638 k_desc.desc, keys,
3639 v_desc.desc, values,
3640 do_desc.desc, dout,
3641 weights.byte_size(), weights.as_raw().0 as *const core::ffi::c_void,
3642 dweights.as_raw().0 as *mut core::ffi::c_void,
3643 work_space.byte_size(), work_space.as_raw().0 as *mut core::ffi::c_void,
3644 reserve_space.byte_size(), reserve_space.as_raw().0 as *mut core::ffi::c_void,
3645 ))
3646}}
3647
/// RNN forward pass (inference or training, per `fwd_mode`), via
/// `cudnnRNNForward`.
///
/// `fwd_mode` and `dev_seq_lengths` are forwarded unchanged. `fwd_mode` is
/// presumably `cudnnForwardMode_t` (0 = inference, 1 = training); confirm.
/// The `weight_space`, `work_space` and `reserve_space` buffers are passed as
/// (size-in-bytes, pointer) pairs using their full byte sizes. Size them with
/// `rnn_weight_space_size` and `rnn_temp_space_sizes`.
///
/// # Safety
/// - `x`/`y` must be device addresses laid out per `x_desc`/`y_desc`.
/// - `hx`/`hy` and `cx`/`cy` must match `h_desc`/`c_desc`. Per the cuDNN
///   docs they may be null (zero initial state / no final state); confirm.
/// - `dev_seq_lengths` is documented by cuDNN as a device array of
///   per-sequence lengths.
///
/// # Errors
/// Fails if the entry point cannot be resolved or cuDNN reports an error.
#[allow(clippy::too_many_arguments)]
pub unsafe fn rnn_forward(
    handle: &Handle,
    rnn: &RnnDescriptor,
    fwd_mode: i32,
    dev_seq_lengths: *const i32,
    x_desc: &RnnDataDescriptor, x: *const core::ffi::c_void,
    y_desc: &RnnDataDescriptor, y: *mut core::ffi::c_void,
    h_desc: &TensorDescriptor,
    hx: *const core::ffi::c_void,
    hy: *mut core::ffi::c_void,
    c_desc: &TensorDescriptor,
    cx: *const core::ffi::c_void,
    cy: *mut core::ffi::c_void,
    weight_space: &DeviceBuffer<u8>,
    work_space: &mut DeviceBuffer<u8>,
    reserve_space: &mut DeviceBuffer<u8>,
) -> Result<()> { unsafe {
    let c = cudnn()?;
    let f = c.cudnn_rnn_forward()?;
    check(f(
        handle.handle, rnn.desc, fwd_mode, dev_seq_lengths,
        x_desc.desc, x,
        y_desc.desc, y,
        h_desc.desc, hx, hy,
        c_desc.desc, cx, cy,
        // Buffers are passed as (size-in-bytes, pointer), in cuDNN's order.
        weight_space.byte_size(), weight_space.as_raw().0 as *const core::ffi::c_void,
        work_space.byte_size(), work_space.as_raw().0 as *mut core::ffi::c_void,
        reserve_space.byte_size(), reserve_space.as_raw().0 as *mut core::ffi::c_void,
    ))
}}
3695
/// RNN backward pass for the data and hidden/cell-state gradients, via
/// `cudnnRNNBackwardData_v8`.
///
/// Writes `dx`, `dhx` and `dcx`. The `weight_space`, `work_space` and
/// `reserve_space` buffers are passed as (size-in-bytes, pointer) pairs using
/// their full byte sizes. `reserve_space` is presumably the one filled by a
/// preceding training-mode `rnn_forward`; confirm.
///
/// # Safety
/// - `y`/`dy` must match `y_desc` and `dx` must match `x_desc`.
/// - `hx`/`dhy`/`dhx` and `cx`/`dcy`/`dcx` must match `h_desc`/`c_desc`. Per
///   the cuDNN docs some of them may be null; confirm before relying on it.
/// - `dev_seq_lengths` is documented by cuDNN as a device array.
///
/// # Errors
/// Fails if the entry point cannot be resolved or cuDNN reports an error.
#[allow(clippy::too_many_arguments)]
pub unsafe fn rnn_backward_data_v8(
    handle: &Handle,
    rnn: &RnnDescriptor,
    dev_seq_lengths: *const i32,
    y_desc: &RnnDataDescriptor,
    y: *const core::ffi::c_void,
    dy: *const core::ffi::c_void,
    x_desc: &RnnDataDescriptor,
    dx: *mut core::ffi::c_void,
    h_desc: &TensorDescriptor,
    hx: *const core::ffi::c_void,
    dhy: *const core::ffi::c_void,
    dhx: *mut core::ffi::c_void,
    c_desc: &TensorDescriptor,
    cx: *const core::ffi::c_void,
    dcy: *const core::ffi::c_void,
    dcx: *mut core::ffi::c_void,
    weight_space: &DeviceBuffer<u8>,
    work_space: &mut DeviceBuffer<u8>,
    reserve_space: &mut DeviceBuffer<u8>,
) -> Result<()> { unsafe {
    let c = cudnn()?;
    let f = c.cudnn_rnn_backward_data_v8()?;
    check(f(
        handle.handle, rnn.desc, dev_seq_lengths,
        y_desc.desc, y, dy,
        x_desc.desc, dx,
        h_desc.desc, hx, dhy, dhx,
        c_desc.desc, cx, dcy, dcx,
        // Buffers are passed as (size-in-bytes, pointer), in cuDNN's order.
        weight_space.byte_size(), weight_space.as_raw().0 as *const core::ffi::c_void,
        work_space.byte_size(), work_space.as_raw().0 as *mut core::ffi::c_void,
        reserve_space.byte_size(), reserve_space.as_raw().0 as *mut core::ffi::c_void,
    ))
}}
3737
3738#[allow(clippy::too_many_arguments)]
3745pub unsafe fn rnn_backward_weights_v8(
3746 handle: &Handle,
3747 rnn: &RnnDescriptor,
3748 add_grad: bool,
3749 dev_seq_lengths: *const i32,
3750 x_desc: &RnnDataDescriptor, x: *const core::ffi::c_void,
3751 h_desc: &TensorDescriptor, hx: *const core::ffi::c_void,
3752 y_desc: &RnnDataDescriptor, y: *const core::ffi::c_void,
3753 dweight_space: &mut DeviceBuffer<u8>,
3754 work_space: &mut DeviceBuffer<u8>,
3755 reserve_space: &mut DeviceBuffer<u8>,
3756) -> Result<()> { unsafe {
3757 let c = cudnn()?;
3758 let f = c.cudnn_rnn_backward_weights_v8()?;
3759 check(f(
3760 handle.handle, rnn.desc, add_grad as core::ffi::c_int, dev_seq_lengths,
3761 x_desc.desc, x,
3762 h_desc.desc, hx,
3763 y_desc.desc, y,
3764 dweight_space.byte_size(), dweight_space.as_raw().0 as *mut core::ffi::c_void,
3765 work_space.byte_size(), work_space.as_raw().0 as *mut core::ffi::c_void,
3766 reserve_space.byte_size(), reserve_space.as_raw().0 as *mut core::ffi::c_void,
3767 ))
3768}}