1#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals)]
12#![warn(missing_debug_implementations)]
13
14use core::ffi::{c_double, c_int, c_void};
15use std::path::PathBuf;
16use std::sync::OnceLock;
17
18use baracuda_core::{Library, LoaderError};
19use baracuda_cuda_sys::runtime::cudaStream_t;
20use baracuda_types::CudaStatus;
21
/// Opaque cuDNN library handle, carried across the FFI boundary as an untyped pointer.
pub type cudnnHandle_t = *mut c_void;
/// Opaque tensor-descriptor handle (untyped pointer).
pub type cudnnTensorDescriptor_t = *mut c_void;
/// Opaque activation-descriptor handle (untyped pointer).
pub type cudnnActivationDescriptor_t = *mut c_void;
/// Opaque filter-descriptor handle (untyped pointer).
pub type cudnnFilterDescriptor_t = *mut c_void;
/// Opaque convolution-descriptor handle (untyped pointer).
pub type cudnnConvolutionDescriptor_t = *mut c_void;
/// Opaque pooling-descriptor handle (untyped pointer).
pub type cudnnPoolingDescriptor_t = *mut c_void;
/// Opaque LRN-descriptor handle (untyped pointer).
pub type cudnnLRNDescriptor_t = *mut c_void;
/// Opaque op-tensor-descriptor handle (untyped pointer).
pub type cudnnOpTensorDescriptor_t = *mut c_void;
/// Opaque reduce-tensor-descriptor handle (untyped pointer).
pub type cudnnReduceTensorDescriptor_t = *mut c_void;
/// Opaque dropout-descriptor handle (untyped pointer).
pub type cudnnDropoutDescriptor_t = *mut c_void;
/// Opaque CTC-loss-descriptor handle (untyped pointer).
pub type cudnnCTCLossDescriptor_t = *mut c_void;
/// Opaque RNN-descriptor handle (untyped pointer).
pub type cudnnRNNDescriptor_t = *mut c_void;
/// Opaque RNN-data-descriptor handle (untyped pointer).
pub type cudnnRNNDataDescriptor_t = *mut c_void;
/// Opaque backend-descriptor handle (untyped pointer).
///
/// All of these aliases resolve to the same `*mut c_void`, so the compiler
/// does not stop one kind of handle being passed where another is expected.
pub type cudnnBackendDescriptor_t = *mut c_void;
50
/// Forward-convolution algorithm selector, passed by value as an `i32`-sized enum.
///
/// Used as the `algo` argument of [`PFN_cudnnConvolutionForward`] and
/// [`PFN_cudnnGetConvolutionForwardWorkspaceSize`], and as the `algo` field of
/// [`cudnnConvolutionFwdAlgoPerf_t`].
///
/// NOTE(review): where this type is written by C (the perf struct), a value
/// outside the variants listed here would be an invalid Rust enum value —
/// confirm the list is complete for the targeted cuDNN version.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnConvolutionFwdAlgo_t {
    ImplicitGemm = 0,
    ImplicitPrecompGemm = 1,
    Gemm = 2,
    Direct = 3,
    Fft = 4,
    FftTiling = 5,
    Winograd = 6,
    WinogradNonfused = 7,
}
72
/// Convolution-mode selector (`i32` discriminants).
///
/// Passed as the `mode` argument of [`PFN_cudnnSetConvolution2dDescriptor`]
/// and [`PFN_cudnnSetConvolutionNdDescriptor`].
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnConvolutionMode_t {
    Convolution = 0,
    CrossCorrelation = 1,
}
82
/// Element / compute data-type selector (`i32` discriminants).
///
/// Taken by value by the descriptor setters in this file (tensor, filter,
/// convolution, op-tensor, reduce-tensor, CTC loss) and written through a
/// `*mut cudnnDataType_t` by [`PFN_cudnnGetTensorNdDescriptor`].
///
/// NOTE(review): because the getter lets C write this type, a code not listed
/// below would produce an invalid Rust enum value — confirm the variant list
/// covers everything the targeted cuDNN version can return.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnDataType_t {
    Float = 0,
    Double = 1,
    Half = 2,
    Int8 = 3,
    Int32 = 4,
    Int8x4 = 5,
    Uint8 = 6,
    Uint8x4 = 7,
    Int8x32 = 8,
    BFloat16 = 9,
}
108
/// Tensor memory-layout selector (`i32` discriminants).
///
/// Passed as the `format` argument of [`PFN_cudnnSetTensor4dDescriptor`],
/// [`PFN_cudnnSetFilter4dDescriptor`] and [`PFN_cudnnSetFilterNdDescriptor`].
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnTensorFormat_t {
    Nchw = 0,
    Nhwc = 1,
    NchwVectC = 2,
}
120
/// Activation-function selector (`i32` discriminants).
///
/// Passed as the `mode` argument of [`PFN_cudnnSetActivationDescriptor`].
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnActivationMode_t {
    Sigmoid = 0,
    Relu = 1,
    Tanh = 2,
    ClippedRelu = 3,
    Elu = 4,
    Identity = 5,
    Swish = 6,
}
140
/// NaN-propagation selector (`i32` discriminants).
///
/// Passed as the `nan_prop` argument of the activation, pooling (2d and Nd),
/// op-tensor and reduce-tensor descriptor setters declared in this file.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnNanPropagation_t {
    NotPropagateNan = 0,
    PropagateNan = 1,
}
150
/// Pooling-mode selector (`i32` discriminants).
///
/// Passed as the `mode` argument of [`PFN_cudnnSetPooling2dDescriptor`] and
/// [`PFN_cudnnSetPoolingNdDescriptor`].
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnPoolingMode_t {
    Max = 0,
    AverageCountIncludePadding = 1,
    AverageCountExcludePadding = 2,
    MaxDeterministic = 3,
}
164
/// Softmax algorithm selector (`i32` discriminants).
///
/// Passed as the `algo` argument of [`PFN_cudnnSoftmaxForward`] and
/// [`PFN_cudnnSoftmaxBackward`].
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnSoftmaxAlgorithm_t {
    Fast = 0,
    Accurate = 1,
    Log = 2,
}
176
/// Softmax mode selector (`i32` discriminants).
///
/// Passed as the `mode` argument of [`PFN_cudnnSoftmaxForward`] and
/// [`PFN_cudnnSoftmaxBackward`].
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnSoftmaxMode_t {
    Instance = 0,
    Channel = 1,
}
186
/// Batch-normalization mode selector (`i32` discriminants).
///
/// Passed as the `mode` argument of the three `PFN_cudnnBatchNormalization*`
/// signatures declared in this file.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnBatchNormMode_t {
    PerActivation = 0,
    Spatial = 1,
    SpatialPersistent = 2,
}
198
/// Element-wise tensor operation selector (`i32` discriminants).
///
/// Passed as the `op` argument of [`PFN_cudnnSetOpTensorDescriptor`].
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnOpTensorOp_t {
    Add = 0,
    Mul = 1,
    Min = 2,
    Max = 3,
    Sqrt = 4,
    Not = 5,
}
216
/// Tensor-reduction operation selector (`i32` discriminants).
///
/// Passed as the `op` argument of [`PFN_cudnnSetReduceTensorDescriptor`].
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnReduceTensorOp_t {
    Add = 0,
    Mul = 1,
    Min = 2,
    Max = 3,
    Amax = 4,
    Avg = 5,
    Norm1 = 6,
    Norm2 = 7,
    MulNoZeros = 8,
}
240
/// Index-output selector for tensor reductions (`i32` discriminants).
///
/// Passed as the `indices` argument of [`PFN_cudnnSetReduceTensorDescriptor`].
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnReduceTensorIndices_t {
    NoIndices = 0,
    FlattenedIndices = 1,
}
250
/// Index-width selector for tensor reductions (`i32` discriminants).
///
/// Passed as the `indices_type` argument of
/// [`PFN_cudnnSetReduceTensorDescriptor`]. Note that the discriminants are not
/// ordered by width: `U32 = 0`, `U64 = 1`, `U16 = 2`, `U8 = 3`.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnIndicesType_t {
    U32 = 0,
    U64 = 1,
    U16 = 2,
    U8 = 3,
}
264
/// RNN cell-type selector (`i32` discriminants).
///
/// Not referenced by any signature in the visible part of this file.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnRNNMode_t {
    ReluRnn = 0,
    TanhRnn = 1,
    Lstm = 2,
    Gru = 3,
}
278
/// RNN direction selector (`i32` discriminants).
///
/// Not referenced by any signature in the visible part of this file.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnDirectionMode_t {
    Unidirectional = 0,
    Bidirectional = 1,
}
288
/// RNN input-mode selector (`i32` discriminants).
///
/// Not referenced by any signature in the visible part of this file.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnRNNInputMode_t {
    LinearInput = 0,
    SkipInput = 1,
}
298
/// RNN algorithm selector (`i32` discriminants).
///
/// Not referenced by any signature in the visible part of this file.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnRNNAlgo_t {
    Standard = 0,
    PersistStatic = 1,
    PersistDynamic = 2,
    PersistStaticSmallH = 3,
}
312
/// Backward-data convolution algorithm selector (`i32` discriminants).
///
/// Used as the `algo` argument of [`PFN_cudnnConvolutionBackwardData`] and
/// [`PFN_cudnnGetConvolutionBackwardDataWorkspaceSize`], and as the `algo`
/// field of [`cudnnConvolutionBwdDataAlgoPerf_t`].
///
/// NOTE(review): where C writes this type (the perf struct), an unlisted
/// value would be an invalid Rust enum value — confirm completeness.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnConvolutionBwdDataAlgo_t {
    Algo0 = 0,
    Algo1 = 1,
    Fft = 2,
    FftTiling = 3,
    Winograd = 4,
    WinogradNonfused = 5,
}
330
/// Backward-filter convolution algorithm selector (`i32` discriminants).
///
/// Used as the `algo` argument of [`PFN_cudnnConvolutionBackwardFilter`] and
/// [`PFN_cudnnGetConvolutionBackwardFilterWorkspaceSize`], and as the `algo`
/// field of [`cudnnConvolutionBwdFilterAlgoPerf_t`].
///
/// The numbering differs from [`cudnnConvolutionBwdDataAlgo_t`]: here
/// `Algo3 = 3` and `FftTiling = 6`.
///
/// NOTE(review): where C writes this type (the perf struct), an unlisted
/// value would be an invalid Rust enum value — confirm completeness.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnConvolutionBwdFilterAlgo_t {
    Algo0 = 0,
    Algo1 = 1,
    Fft = 2,
    Algo3 = 3,
    Winograd = 4,
    WinogradNonfused = 5,
    FftTiling = 6,
}
350
/// Backend descriptor kind (`i32` discriminants).
///
/// Passed as the `descriptor_type` argument of
/// [`PFN_cudnnBackendCreateDescriptor`].
///
/// The numbering below is not contiguous: it runs 0..=28 and then jumps to
/// 30 and 31 (there is no variant for 29).
///
/// NOTE(review): these discriminants are sent to cuDNN verbatim, so they must
/// match `cudnnBackendDescriptorType_t` in the `cudnn_backend.h` of the
/// targeted cuDNN version. The list here has no engine-heuristics entry and a
/// gap at 29, which looks inconsistent with the upstream header — verify every
/// value against the header before relying on this enum.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnBackendDescriptorType_t {
    PointwiseDescriptor = 0,
    ConvolutionDescriptor = 1,
    EngineDescriptor = 2,
    EngineCfgDescriptor = 3,
    ExecutionPlanDescriptor = 4,
    IntermediateInfoDescriptor = 5,
    KnobChoiceDescriptor = 6,
    KnobInfoDescriptor = 7,
    LayoutInfoDescriptor = 8,
    OperationConvolutionForwardDescriptor = 9,
    OperationConvolutionBackwardFilterDescriptor = 10,
    OperationConvolutionBackwardDataDescriptor = 11,
    OperationPointwiseDescriptor = 12,
    OperationGenStatsDescriptor = 13,
    OperationReductionDescriptor = 14,
    OperationBnFinalizeStatisticsDescriptor = 15,
    OperationGraphDescriptor = 16,
    VariantPackDescriptor = 17,
    TensorDescriptor = 18,
    MatmulDescriptor = 19,
    OperationMatmulDescriptor = 20,
    OperationBnBwdWeightsDescriptor = 21,
    ResampleDescriptor = 22,
    OperationResampleFwdDescriptor = 23,
    OperationResampleBwdDescriptor = 24,
    OperationConcatDescriptor = 25,
    OperationSignalDescriptor = 26,
    OperationNormForwardDescriptor = 27,
    OperationNormBackwardDescriptor = 28,
    // Discriminant 29 is intentionally or accidentally unassigned — see NOTE(review) above.
    OperationRngDescriptor = 30,
    RngDescriptor = 31,
}
418
/// Backend attribute identifier (`i32` discriminants).
///
/// Passed as the `attribute_name` argument of
/// [`PFN_cudnnBackendSetAttribute`] and [`PFN_cudnnBackendGetAttribute`].
///
/// Variants are numbered in groups by prefix: `Pointwise*` from 0, `Tensor*`
/// from 100, `Convolution*` from 200, `OperationGraph*` from 500 and
/// `ExecutionPlan*` from 600.
///
/// NOTE(review): these values are sent to cuDNN verbatim and must equal the
/// `CUDNN_ATTR_*` constants in the targeted `cudnn_backend.h`. The group bases
/// used here look different from the upstream header — verify each value
/// before use.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnBackendAttributeName_t {
    PointwiseMode = 0,
    PointwiseMathPrec = 1,
    PointwiseNanPropagation = 2,
    PointwiseReluLowerClip = 3,
    PointwiseReluUpperClip = 4,
    PointwiseEluAlpha = 5,
    TensorUniqueId = 100,
    TensorDataType = 101,
    TensorByteAlignment = 102,
    TensorDimensions = 103,
    TensorStrides = 104,
    ConvolutionCompType = 200,
    ConvolutionConvMode = 201,
    ConvolutionDilations = 202,
    ConvolutionFilterStrides = 203,
    ConvolutionPrePaddings = 204,
    ConvolutionPostPaddings = 205,
    ConvolutionSpatialDims = 206,
    OperationGraphHandle = 500,
    OperationGraphOps = 501,
    ExecutionPlanHandle = 600,
    ExecutionPlanEngineConfig = 601,
    ExecutionPlanWorkspaceSize = 602,
}
475
/// Element type tag for backend attribute values (`i32` discriminants).
///
/// Passed as the `attribute_type` argument of
/// [`PFN_cudnnBackendSetAttribute`] and [`PFN_cudnnBackendGetAttribute`],
/// alongside an untyped `array_of_elements` pointer.
///
/// NOTE(review): because the payload pointer is untyped, a wrong tag value
/// makes cuDNN misinterpret the buffer. Verify each discriminant (in
/// particular `PointwiseMode = 6` and `PointerT = 14`) against
/// `cudnnBackendAttributeType_t` in the targeted `cudnn_backend.h`.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnBackendAttributeType_t {
    Handle = 0,
    DataType = 1,
    Boolean = 2,
    Int64 = 3,
    FloatValue = 4,
    DoubleValue = 5,
    PointwiseMode = 6,
    ConvolutionMode = 7,
    HeurMode = 8,
    KnobType = 9,
    NanPropagation = 10,
    NumericalNote = 11,
    LayoutType = 12,
    AttribName = 13,
    PointerT = 14,
    BackendDescriptor = 15,
    GenstatsMode = 16,
    BnFinalizeStatsMode = 17,
    ReductionOperatorType = 18,
    BehaviorNote = 19,
    TensorReorderingMode = 20,
    ResampleMode = 21,
    PaddingMode = 22,
    IntArray = 23,
    RngDistribution = 24,
    NormMode = 25,
    NormFwdPhase = 26,
    RngNormal = 27,
    RngUniform = 28,
}
539
/// Math-mode selector (`i32` discriminants).
///
/// Appears as the `math_type` field of the three `cudnnConvolution*AlgoPerf_t`
/// structs in this file.
///
/// NOTE(review): those structs are presumably filled in by C; an unlisted
/// value there would be an invalid Rust enum value — confirm completeness.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnMathType_t {
    DefaultMath = 0,
    TensorOpMath = 1,
    TensorOpMathAllowConversion = 2,
    FmaMath = 3,
}
555
/// Reorder-type selector (`i32` discriminants).
///
/// Not referenced by any signature in the visible part of this file.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnReorderType_t {
    DefaultReorder = 0,
    NoReorder = 1,
}
565
/// Normalization-mode selector (`i32` discriminants).
///
/// Not referenced by any signature in the visible part of this file.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnNormMode_t {
    PerActivation = 0,
    PerChannel = 1,
}
575
/// Normalization-algorithm selector (`i32` discriminants).
///
/// Not referenced by any signature in the visible part of this file.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnNormAlgo_t {
    Standard = 0,
    Persist = 1,
}
585
/// Fused-operation selector for normalization (`i32` discriminants).
///
/// Not referenced by any signature in the visible part of this file.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnNormOps_t {
    Norm = 0,
    NormActivation = 1,
    NormAddActivation = 2,
}
597
/// Fused-operation selector for batch normalization (`i32` discriminants).
///
/// Not referenced by any signature in the visible part of this file.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnBatchNormOps_t {
    Bn = 0,
    BnActivation = 1,
    BnAddActivation = 2,
}
609
/// Determinism flag (`i32` discriminants).
///
/// Appears as the `determinism` field of the three
/// `cudnnConvolution*AlgoPerf_t` structs in this file.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cudnnDeterminism_t {
    NonDeterministic = 0,
    Deterministic = 1,
}
620
/// `#[repr(C)]` performance record for one forward-convolution algorithm.
///
/// Field order and types are part of the C ABI and must not be rearranged.
///
/// NOTE(review): presumably filled in by cuDNN's algorithm-search entry
/// points, which are not declared in the visible part of this file. If so,
/// the enum-typed fields (`algo`, `determinism`, `math_type`) are written by
/// C and must only ever receive values that the Rust enums list.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct cudnnConvolutionFwdAlgoPerf_t {
    /// Algorithm this record describes.
    pub algo: cudnnConvolutionFwdAlgo_t,
    /// Status associated with this algorithm.
    pub status: cudnnStatus_t,
    /// Measured time (units not visible here — TODO confirm).
    pub time: f32,
    /// Memory figure for the algorithm (presumably workspace bytes — confirm).
    pub memory: usize,
    /// Determinism flag for the algorithm.
    pub determinism: cudnnDeterminism_t,
    /// Math mode associated with the algorithm.
    pub math_type: cudnnMathType_t,
    /// Padding / reserved words kept for ABI layout.
    pub reserved: [c_int; 3],
}
641
/// `#[repr(C)]` performance record for one backward-data convolution algorithm.
///
/// Same layout as [`cudnnConvolutionFwdAlgoPerf_t`] except for the type of
/// `algo`. Field order and types are part of the C ABI.
///
/// NOTE(review): presumably written by C; the enum-typed fields must only
/// receive values that the Rust enums list.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct cudnnConvolutionBwdDataAlgoPerf_t {
    /// Algorithm this record describes.
    pub algo: cudnnConvolutionBwdDataAlgo_t,
    /// Status associated with this algorithm.
    pub status: cudnnStatus_t,
    /// Measured time (units not visible here — TODO confirm).
    pub time: f32,
    /// Memory figure for the algorithm (presumably workspace bytes — confirm).
    pub memory: usize,
    /// Determinism flag for the algorithm.
    pub determinism: cudnnDeterminism_t,
    /// Math mode associated with the algorithm.
    pub math_type: cudnnMathType_t,
    /// Padding / reserved words kept for ABI layout.
    pub reserved: [c_int; 3],
}
661
/// `#[repr(C)]` performance record for one backward-filter convolution algorithm.
///
/// Same layout as [`cudnnConvolutionFwdAlgoPerf_t`] except for the type of
/// `algo`. Field order and types are part of the C ABI.
///
/// NOTE(review): presumably written by C; the enum-typed fields must only
/// receive values that the Rust enums list.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct cudnnConvolutionBwdFilterAlgoPerf_t {
    /// Algorithm this record describes.
    pub algo: cudnnConvolutionBwdFilterAlgo_t,
    /// Status associated with this algorithm.
    pub status: cudnnStatus_t,
    /// Measured time (units not visible here — TODO confirm).
    pub time: f32,
    /// Memory figure for the algorithm (presumably workspace bytes — confirm).
    pub memory: usize,
    /// Determinism flag for the algorithm.
    pub determinism: cudnnDeterminism_t,
    /// Math mode associated with the algorithm.
    pub math_type: cudnnMathType_t,
    /// Padding / reserved words kept for ABI layout.
    pub reserved: [c_int; 3],
}
681
/// Opaque tensor-transform-descriptor handle (untyped pointer).
pub type cudnnTensorTransformDescriptor_t = *mut c_void;
/// Opaque attention-descriptor handle (untyped pointer).
pub type cudnnAttnDescriptor_t = *mut c_void;
/// Opaque sequence-data-descriptor handle (untyped pointer).
pub type cudnnSeqDataDescriptor_t = *mut c_void;
690
/// Raw cuDNN status code.
///
/// This is a transparent wrapper around `i32` rather than a Rust enum, so any
/// integer returned across the FFI boundary is a valid value of this type.
/// The associated constants name the codes this crate knows about.
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct cudnnStatus_t(pub i32);

impl cudnnStatus_t {
    /// Code `0`.
    pub const SUCCESS: Self = Self(0);
    /// Code `1`.
    pub const NOT_INITIALIZED: Self = Self(1);
    /// Code `2`.
    pub const ALLOC_FAILED: Self = Self(2);
    /// Code `3`.
    pub const BAD_PARAM: Self = Self(3);
    /// Code `4`.
    pub const INTERNAL_ERROR: Self = Self(4);
    /// Code `5`.
    pub const INVALID_VALUE: Self = Self(5);
    /// Code `6`.
    pub const ARCH_MISMATCH: Self = Self(6);
    /// Code `7`.
    pub const MAPPING_ERROR: Self = Self(7);
    /// Code `8`.
    pub const EXECUTION_FAILED: Self = Self(8);
    /// Code `9`.
    pub const NOT_SUPPORTED: Self = Self(9);
    /// Code `10`.
    pub const LICENSE_ERROR: Self = Self(10);

    /// Returns `true` exactly when the wrapped code equals [`Self::SUCCESS`].
    pub const fn is_success(self) -> bool {
        self.0 == Self::SUCCESS.0
    }
}
727
728impl CudaStatus for cudnnStatus_t {
729 fn code(self) -> i32 {
730 self.0
731 }
732 fn name(self) -> &'static str {
733 match self.0 {
734 0 => "CUDNN_STATUS_SUCCESS",
735 1 => "CUDNN_STATUS_NOT_INITIALIZED",
736 2 => "CUDNN_STATUS_ALLOC_FAILED",
737 3 => "CUDNN_STATUS_BAD_PARAM",
738 4 => "CUDNN_STATUS_INTERNAL_ERROR",
739 5 => "CUDNN_STATUS_INVALID_VALUE",
740 6 => "CUDNN_STATUS_ARCH_MISMATCH",
741 8 => "CUDNN_STATUS_EXECUTION_FAILED",
742 9 => "CUDNN_STATUS_NOT_SUPPORTED",
743 _ => "CUDNN_STATUS_UNRECOGNIZED",
744 }
745 }
746 fn description(self) -> &'static str {
747 match self.0 {
748 0 => "success",
749 1 => "cuDNN not initialized",
750 3 => "bad parameter",
751 9 => "operation not supported on this device/version",
752 _ => "unrecognized cuDNN status code",
753 }
754 }
755 fn is_success(self) -> bool {
756 cudnnStatus_t::is_success(self)
757 }
758 fn library(self) -> &'static str {
759 "cudnn"
760 }
761}
762
/// Signature of `cudnnCreate`; `handle` is a pointer to a handle slot
/// (presumably an out-parameter — confirm against the cuDNN API reference).
pub type PFN_cudnnCreate = unsafe extern "C" fn(handle: *mut cudnnHandle_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroy`; takes the handle by value.
pub type PFN_cudnnDestroy = unsafe extern "C" fn(handle: cudnnHandle_t) -> cudnnStatus_t;
/// Signature of `cudnnSetStream`; pairs a cuDNN handle with a CUDA stream.
pub type PFN_cudnnSetStream =
    unsafe extern "C" fn(handle: cudnnHandle_t, stream: cudaStream_t) -> cudnnStatus_t;
/// Signature of `cudnnGetVersion`; the only signature here that returns a
/// plain `usize` instead of a [`cudnnStatus_t`].
pub type PFN_cudnnGetVersion = unsafe extern "C" fn() -> usize;
774
/// Signature of `cudnnCreateTensorDescriptor`; `desc` points to a descriptor
/// slot (presumably an out-parameter).
pub type PFN_cudnnCreateTensorDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnTensorDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroyTensorDescriptor`.
pub type PFN_cudnnDestroyTensorDescriptor =
    unsafe extern "C" fn(desc: cudnnTensorDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnSetTensor4dDescriptor`.
///
/// Argument order is `format` then `data_type`, the reverse of
/// [`PFN_cudnnSetFilter4dDescriptor`].
pub type PFN_cudnnSetTensor4dDescriptor = unsafe extern "C" fn(
    desc: cudnnTensorDescriptor_t,
    format: cudnnTensorFormat_t,
    data_type: cudnnDataType_t,
    n: c_int,
    c: c_int,
    h: c_int,
    w: c_int,
) -> cudnnStatus_t;
791
/// Signature of `cudnnCreateActivationDescriptor`; `desc` points to a
/// descriptor slot (presumably an out-parameter).
pub type PFN_cudnnCreateActivationDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnActivationDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroyActivationDescriptor`.
pub type PFN_cudnnDestroyActivationDescriptor =
    unsafe extern "C" fn(desc: cudnnActivationDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnSetActivationDescriptor`; `coef` is a C `double`.
pub type PFN_cudnnSetActivationDescriptor = unsafe extern "C" fn(
    desc: cudnnActivationDescriptor_t,
    mode: cudnnActivationMode_t,
    nan_prop: cudnnNanPropagation_t,
    coef: c_double,
) -> cudnnStatus_t;
805
/// Signature of `cudnnActivationForward`.
///
/// `alpha` / `beta` are untyped pointers (presumably to host scalars whose
/// type depends on the tensor data type — confirm); `x` is read-only and `y`
/// is the writable buffer.
pub type PFN_cudnnActivationForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    activation_desc: cudnnActivationDescriptor_t,
    alpha: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    beta: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
) -> cudnnStatus_t;
817
/// Signature of `cudnnCreateFilterDescriptor`; `desc` points to a descriptor
/// slot (presumably an out-parameter).
pub type PFN_cudnnCreateFilterDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnFilterDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroyFilterDescriptor`.
pub type PFN_cudnnDestroyFilterDescriptor =
    unsafe extern "C" fn(desc: cudnnFilterDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnSetFilter4dDescriptor`.
///
/// Argument order is `data_type` then `format`, the reverse of
/// [`PFN_cudnnSetTensor4dDescriptor`].
pub type PFN_cudnnSetFilter4dDescriptor = unsafe extern "C" fn(
    desc: cudnnFilterDescriptor_t,
    data_type: cudnnDataType_t,
    format: cudnnTensorFormat_t,
    k: c_int,
    c: c_int,
    h: c_int,
    w: c_int,
) -> cudnnStatus_t;
836
/// Signature of `cudnnCreateConvolutionDescriptor`; `desc` points to a
/// descriptor slot (presumably an out-parameter).
pub type PFN_cudnnCreateConvolutionDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnConvolutionDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroyConvolutionDescriptor`.
pub type PFN_cudnnDestroyConvolutionDescriptor =
    unsafe extern "C" fn(desc: cudnnConvolutionDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnSetConvolution2dDescriptor`.
///
/// `u` / `v` sit between the padding and dilation pairs (presumably the
/// vertical / horizontal filter strides — confirm against the API reference).
pub type PFN_cudnnSetConvolution2dDescriptor = unsafe extern "C" fn(
    desc: cudnnConvolutionDescriptor_t,
    pad_h: c_int,
    pad_w: c_int,
    u: c_int,
    v: c_int,
    dilation_h: c_int,
    dilation_w: c_int,
    mode: cudnnConvolutionMode_t,
    compute_type: cudnnDataType_t,
) -> cudnnStatus_t;
855
/// Signature of `cudnnGetConvolution2dForwardOutputDim`.
///
/// Takes no library handle; `n`, `c`, `h`, `w` are writable `int` pointers
/// (presumably out-parameters receiving the output dimensions).
pub type PFN_cudnnGetConvolution2dForwardOutputDim = unsafe extern "C" fn(
    conv_desc: cudnnConvolutionDescriptor_t,
    input_desc: cudnnTensorDescriptor_t,
    filter_desc: cudnnFilterDescriptor_t,
    n: *mut c_int,
    c: *mut c_int,
    h: *mut c_int,
    w: *mut c_int,
) -> cudnnStatus_t;
866
/// Signature of `cudnnGetConvolutionForwardWorkspaceSize`.
///
/// `size_in_bytes` is a writable `usize` pointer (presumably an
/// out-parameter).
pub type PFN_cudnnGetConvolutionForwardWorkspaceSize = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    x_desc: cudnnTensorDescriptor_t,
    w_desc: cudnnFilterDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    y_desc: cudnnTensorDescriptor_t,
    algo: cudnnConvolutionFwdAlgo_t,
    size_in_bytes: *mut usize,
) -> cudnnStatus_t;
877
/// Signature of `cudnnConvolutionForward`.
///
/// Read-only inputs are `x` and `w`; `y` and `workspace` are writable.
/// `workspace_size` is passed alongside `workspace` (presumably its size in
/// bytes — confirm).
pub type PFN_cudnnConvolutionForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    alpha: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    w_desc: cudnnFilterDescriptor_t,
    w: *const c_void,
    conv_desc: cudnnConvolutionDescriptor_t,
    algo: cudnnConvolutionFwdAlgo_t,
    workspace: *mut c_void,
    workspace_size: usize,
    beta: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
) -> cudnnStatus_t;
894
/// Signature of `cudnnConvolutionBackwardData`.
///
/// Read-only inputs are `w` and `dy`; `dx` and `workspace` are writable.
pub type PFN_cudnnConvolutionBackwardData = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    alpha: *const c_void,
    w_desc: cudnnFilterDescriptor_t,
    w: *const c_void,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    conv_desc: cudnnConvolutionDescriptor_t,
    algo: cudnnConvolutionBwdDataAlgo_t,
    workspace: *mut c_void,
    workspace_size: usize,
    beta: *const c_void,
    dx_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
) -> cudnnStatus_t;
911
/// Signature of `cudnnConvolutionBackwardFilter`.
///
/// Read-only inputs are `x` and `dy`; `dw` and `workspace` are writable.
/// Note that `dw_desc` is a filter descriptor, unlike the tensor descriptors
/// used for `x` and `dy`.
pub type PFN_cudnnConvolutionBackwardFilter = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    alpha: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    conv_desc: cudnnConvolutionDescriptor_t,
    algo: cudnnConvolutionBwdFilterAlgo_t,
    workspace: *mut c_void,
    workspace_size: usize,
    beta: *const c_void,
    dw_desc: cudnnFilterDescriptor_t,
    dw: *mut c_void,
) -> cudnnStatus_t;
928
/// Signature of `cudnnConvolutionBackwardBias`.
///
/// `dy` is read-only; `db` is the writable buffer.
pub type PFN_cudnnConvolutionBackwardBias = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    alpha: *const c_void,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    beta: *const c_void,
    db_desc: cudnnTensorDescriptor_t,
    db: *mut c_void,
) -> cudnnStatus_t;
939
/// Signature of `cudnnGetConvolutionBackwardDataWorkspaceSize`.
///
/// `size_in_bytes` is a writable `usize` pointer (presumably an
/// out-parameter).
pub type PFN_cudnnGetConvolutionBackwardDataWorkspaceSize = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    w_desc: cudnnFilterDescriptor_t,
    dy_desc: cudnnTensorDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    dx_desc: cudnnTensorDescriptor_t,
    algo: cudnnConvolutionBwdDataAlgo_t,
    size_in_bytes: *mut usize,
) -> cudnnStatus_t;

/// Signature of `cudnnGetConvolutionBackwardFilterWorkspaceSize`.
///
/// `size_in_bytes` is a writable `usize` pointer (presumably an
/// out-parameter).
pub type PFN_cudnnGetConvolutionBackwardFilterWorkspaceSize = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    x_desc: cudnnTensorDescriptor_t,
    dy_desc: cudnnTensorDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    dw_desc: cudnnFilterDescriptor_t,
    algo: cudnnConvolutionBwdFilterAlgo_t,
    size_in_bytes: *mut usize,
) -> cudnnStatus_t;
961
/// Signature of `cudnnCreatePoolingDescriptor`; `desc` points to a descriptor
/// slot (presumably an out-parameter).
pub type PFN_cudnnCreatePoolingDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnPoolingDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroyPoolingDescriptor`.
pub type PFN_cudnnDestroyPoolingDescriptor =
    unsafe extern "C" fn(desc: cudnnPoolingDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnSetPooling2dDescriptor`: mode, NaN handling, then
/// window, padding and stride as height/vertical-then-width/horizontal pairs.
pub type PFN_cudnnSetPooling2dDescriptor = unsafe extern "C" fn(
    desc: cudnnPoolingDescriptor_t,
    mode: cudnnPoolingMode_t,
    nan_prop: cudnnNanPropagation_t,
    window_h: c_int,
    window_w: c_int,
    vertical_padding: c_int,
    horizontal_padding: c_int,
    vertical_stride: c_int,
    horizontal_stride: c_int,
) -> cudnnStatus_t;
/// Signature of `cudnnPoolingForward`; `x` is read-only, `y` is writable.
pub type PFN_cudnnPoolingForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    pool_desc: cudnnPoolingDescriptor_t,
    alpha: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    beta: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
) -> cudnnStatus_t;
/// Signature of `cudnnPoolingBackward`; `y`, `dy` and `x` are read-only,
/// `dx` is writable.
pub type PFN_cudnnPoolingBackward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    pool_desc: cudnnPoolingDescriptor_t,
    alpha: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *const c_void,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    beta: *const c_void,
    dx_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
) -> cudnnStatus_t;
1008
/// Signature of `cudnnSoftmaxForward`; `x` is read-only, `y` is writable.
pub type PFN_cudnnSoftmaxForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    algo: cudnnSoftmaxAlgorithm_t,
    mode: cudnnSoftmaxMode_t,
    alpha: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    beta: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
) -> cudnnStatus_t;

/// Signature of `cudnnSoftmaxBackward`; `y` and `dy` are read-only, `dx` is
/// writable.
pub type PFN_cudnnSoftmaxBackward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    algo: cudnnSoftmaxAlgorithm_t,
    mode: cudnnSoftmaxMode_t,
    alpha: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *const c_void,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    beta: *const c_void,
    dx_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
) -> cudnnStatus_t;
1038
/// Signature of `cudnnBatchNormalizationForwardInference`.
///
/// Unlike most signatures here, `alpha` and `beta` are adjacent and come
/// before the tensor arguments. A single descriptor
/// (`bn_scale_bias_mean_var_desc`) precedes the four read-only parameter
/// buffers; `y` is the only writable buffer.
pub type PFN_cudnnBatchNormalizationForwardInference = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnBatchNormMode_t,
    alpha: *const c_void,
    beta: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
    bn_scale_bias_mean_var_desc: cudnnTensorDescriptor_t,
    bn_scale: *const c_void,
    bn_bias: *const c_void,
    estimated_mean: *const c_void,
    estimated_variance: *const c_void,
    epsilon: c_double,
) -> cudnnStatus_t;
1058
/// Signature of `cudnnBatchNormalizationForwardTraining`.
///
/// Writable buffers are `y`, the two `result_running_*` buffers and the two
/// `result_save_*` buffers; `exponential_average_factor` and `epsilon` are C
/// `double`s passed by value.
pub type PFN_cudnnBatchNormalizationForwardTraining = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnBatchNormMode_t,
    alpha: *const c_void,
    beta: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
    bn_scale_bias_mean_var_desc: cudnnTensorDescriptor_t,
    bn_scale: *const c_void,
    bn_bias: *const c_void,
    exponential_average_factor: c_double,
    result_running_mean: *mut c_void,
    result_running_variance: *mut c_void,
    epsilon: c_double,
    result_save_mean: *mut c_void,
    result_save_inv_variance: *mut c_void,
) -> cudnnStatus_t;
1079
/// Signature of `cudnnBatchNormalizationBackward`.
///
/// Takes two alpha/beta pairs (data-diff and param-diff). Writable buffers
/// are `dx`, `bn_scale_result` and `bn_bias_result`; everything else is
/// read-only.
pub type PFN_cudnnBatchNormalizationBackward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnBatchNormMode_t,
    alpha_data_diff: *const c_void,
    beta_data_diff: *const c_void,
    alpha_param_diff: *const c_void,
    beta_param_diff: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    dx_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
    bn_scale_bias_diff_desc: cudnnTensorDescriptor_t,
    bn_scale: *const c_void,
    bn_scale_result: *mut c_void,
    bn_bias_result: *mut c_void,
    epsilon: c_double,
    saved_mean: *const c_void,
    saved_inv_variance: *const c_void,
) -> cudnnStatus_t;
1102
/// Signature of `cudnnCreateOpTensorDescriptor`; `desc` points to a
/// descriptor slot (presumably an out-parameter).
pub type PFN_cudnnCreateOpTensorDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnOpTensorDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroyOpTensorDescriptor`.
pub type PFN_cudnnDestroyOpTensorDescriptor =
    unsafe extern "C" fn(desc: cudnnOpTensorDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnSetOpTensorDescriptor`.
pub type PFN_cudnnSetOpTensorDescriptor = unsafe extern "C" fn(
    desc: cudnnOpTensorDescriptor_t,
    op: cudnnOpTensorOp_t,
    compute_type: cudnnDataType_t,
    nan_prop: cudnnNanPropagation_t,
) -> cudnnStatus_t;
/// Signature of `cudnnOpTensor`; `a` and `b` are read-only inputs with their
/// own scaling pointers (`alpha1`, `alpha2`), `c` is writable.
pub type PFN_cudnnOpTensor = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    desc: cudnnOpTensorDescriptor_t,
    alpha1: *const c_void,
    a_desc: cudnnTensorDescriptor_t,
    a: *const c_void,
    alpha2: *const c_void,
    b_desc: cudnnTensorDescriptor_t,
    b: *const c_void,
    beta: *const c_void,
    c_desc: cudnnTensorDescriptor_t,
    c: *mut c_void,
) -> cudnnStatus_t;
1132
/// Signature of `cudnnCreateReduceTensorDescriptor`; `desc` points to a
/// descriptor slot (presumably an out-parameter).
pub type PFN_cudnnCreateReduceTensorDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnReduceTensorDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroyReduceTensorDescriptor`.
pub type PFN_cudnnDestroyReduceTensorDescriptor =
    unsafe extern "C" fn(desc: cudnnReduceTensorDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnSetReduceTensorDescriptor`.
pub type PFN_cudnnSetReduceTensorDescriptor = unsafe extern "C" fn(
    desc: cudnnReduceTensorDescriptor_t,
    op: cudnnReduceTensorOp_t,
    compute_type: cudnnDataType_t,
    nan_prop: cudnnNanPropagation_t,
    indices: cudnnReduceTensorIndices_t,
    indices_type: cudnnIndicesType_t,
) -> cudnnStatus_t;
/// Signature of `cudnnGetReductionWorkspaceSize`; `workspace_size` is a
/// writable `usize` pointer (presumably an out-parameter).
pub type PFN_cudnnGetReductionWorkspaceSize = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    desc: cudnnReduceTensorDescriptor_t,
    a_desc: cudnnTensorDescriptor_t,
    c_desc: cudnnTensorDescriptor_t,
    workspace_size: *mut usize,
) -> cudnnStatus_t;
/// Signature of `cudnnReduceTensor`.
///
/// `indices` and `workspace` are writable buffers, each followed by a size
/// argument; `a` is read-only and `c` is writable.
pub type PFN_cudnnReduceTensor = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    desc: cudnnReduceTensorDescriptor_t,
    indices: *mut c_void,
    indices_size: usize,
    workspace: *mut c_void,
    workspace_size: usize,
    alpha: *const c_void,
    a_desc: cudnnTensorDescriptor_t,
    a: *const c_void,
    beta: *const c_void,
    c_desc: cudnnTensorDescriptor_t,
    c: *mut c_void,
) -> cudnnStatus_t;
1171
/// Signature of `cudnnAddTensor`; `a` is read-only, `c` is writable.
pub type PFN_cudnnAddTensor = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    alpha: *const c_void,
    a_desc: cudnnTensorDescriptor_t,
    a: *const c_void,
    beta: *const c_void,
    c_desc: cudnnTensorDescriptor_t,
    c: *mut c_void,
) -> cudnnStatus_t;

/// Signature of `cudnnTransformTensor`; `x` is read-only, `y` is writable.
pub type PFN_cudnnTransformTensor = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    alpha: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    beta: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
) -> cudnnStatus_t;

/// Signature of `cudnnScaleTensor`; operates on the writable buffer `y` with
/// a single `alpha` pointer (no `beta`).
pub type PFN_cudnnScaleTensor = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
    alpha: *const c_void,
) -> cudnnStatus_t;

/// Signature of `cudnnSetTensor`; operates on the writable buffer `y` with a
/// read-only `value_ptr`.
pub type PFN_cudnnSetTensor = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
    value_ptr: *const c_void,
) -> cudnnStatus_t;
1209
/// Signature of `cudnnCreateLRNDescriptor`; `desc` points to a descriptor
/// slot (presumably an out-parameter).
pub type PFN_cudnnCreateLRNDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnLRNDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroyLRNDescriptor`.
pub type PFN_cudnnDestroyLRNDescriptor =
    unsafe extern "C" fn(desc: cudnnLRNDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnSetLRNDescriptor`.
///
/// NOTE(review): `lrn_n` is declared as `c_int` here — confirm its
/// signedness against the C prototype.
pub type PFN_cudnnSetLRNDescriptor = unsafe extern "C" fn(
    desc: cudnnLRNDescriptor_t,
    lrn_n: c_int,
    lrn_alpha: c_double,
    lrn_beta: c_double,
    lrn_k: c_double,
) -> cudnnStatus_t;
/// Signature of `cudnnLRNCrossChannelForward`.
///
/// `lrn_mode` is passed as a bare `c_int` rather than a dedicated enum type.
pub type PFN_cudnnLRNCrossChannelForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    lrn_desc: cudnnLRNDescriptor_t,
    lrn_mode: c_int,
    alpha: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    beta: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
) -> cudnnStatus_t;
1238
/// Signature of `cudnnCreateDropoutDescriptor`; `desc` points to a descriptor
/// slot (presumably an out-parameter).
pub type PFN_cudnnCreateDropoutDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnDropoutDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroyDropoutDescriptor`.
pub type PFN_cudnnDestroyDropoutDescriptor =
    unsafe extern "C" fn(desc: cudnnDropoutDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDropoutGetStatesSize`; `size_in_bytes` is a writable
/// `usize` pointer (presumably an out-parameter).
pub type PFN_cudnnDropoutGetStatesSize =
    unsafe extern "C" fn(handle: cudnnHandle_t, size_in_bytes: *mut usize) -> cudnnStatus_t;
/// Signature of `cudnnDropoutGetReserveSpaceSize`; takes a tensor descriptor
/// instead of a library handle.
pub type PFN_cudnnDropoutGetReserveSpaceSize = unsafe extern "C" fn(
    x_desc: cudnnTensorDescriptor_t,
    size_in_bytes: *mut usize,
) -> cudnnStatus_t;
/// Signature of `cudnnSetDropoutDescriptor`.
///
/// The descriptor comes before the handle here, unlike the forward/backward
/// signatures where the handle is first.
pub type PFN_cudnnSetDropoutDescriptor = unsafe extern "C" fn(
    desc: cudnnDropoutDescriptor_t,
    handle: cudnnHandle_t,
    dropout: f32,
    states: *mut c_void,
    state_size: usize,
    seed: u64,
) -> cudnnStatus_t;
/// Signature of `cudnnDropoutForward`; `x` is read-only, `y` and
/// `reserve_space` are writable.
pub type PFN_cudnnDropoutForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    desc: cudnnDropoutDescriptor_t,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
    reserve_space: *mut c_void,
    reserve_space_size: usize,
) -> cudnnStatus_t;
/// Signature of `cudnnDropoutBackward`; `dy` is read-only, `dx` and
/// `reserve_space` are writable.
pub type PFN_cudnnDropoutBackward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    desc: cudnnDropoutDescriptor_t,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    dx_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
    reserve_space: *mut c_void,
    reserve_space_size: usize,
) -> cudnnStatus_t;
1286
/// Signature of `cudnnCreateRNNDescriptor`; `desc` points to a descriptor
/// slot (presumably an out-parameter).
pub type PFN_cudnnCreateRNNDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnRNNDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroyRNNDescriptor`.
pub type PFN_cudnnDestroyRNNDescriptor =
    unsafe extern "C" fn(desc: cudnnRNNDescriptor_t) -> cudnnStatus_t;

/// Signature of `cudnnCreateRNNDataDescriptor`; `desc` points to a descriptor
/// slot (presumably an out-parameter).
pub type PFN_cudnnCreateRNNDataDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnRNNDataDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroyRNNDataDescriptor`.
pub type PFN_cudnnDestroyRNNDataDescriptor =
    unsafe extern "C" fn(desc: cudnnRNNDataDescriptor_t) -> cudnnStatus_t;
1302
/// Signature of `cudnnRNNForward`.
///
/// `fwd_mode` is a bare `c_int` rather than a dedicated enum. `x_desc` /
/// `y_desc` are RNN-data descriptors while `h_desc` / `c_desc` are tensor
/// descriptors. Each of the three spaces is passed size-first: the weight
/// space is read-only, the work and reserve spaces are writable.
///
/// NOTE(review): `dev_seq_lengths` is presumably a device pointer, going by
/// its name — confirm against the API reference.
pub type PFN_cudnnRNNForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    rnn_desc: cudnnRNNDescriptor_t,
    fwd_mode: c_int,
    dev_seq_lengths: *const i32,
    x_desc: cudnnRNNDataDescriptor_t,
    x: *const c_void,
    y_desc: cudnnRNNDataDescriptor_t,
    y: *mut c_void,
    h_desc: cudnnTensorDescriptor_t,
    hx: *const c_void,
    hy: *mut c_void,
    c_desc: cudnnTensorDescriptor_t,
    cx: *const c_void,
    cy: *mut c_void,
    weight_space_size: usize,
    weight_space: *const c_void,
    work_space_size: usize,
    work_space: *mut c_void,
    reserve_space_size: usize,
    reserve_space: *mut c_void,
) -> cudnnStatus_t;
1326
/// Signature of `cudnnBackendCreateDescriptor`; `descriptor` points to a
/// descriptor slot (presumably an out-parameter).
pub type PFN_cudnnBackendCreateDescriptor = unsafe extern "C" fn(
    descriptor_type: cudnnBackendDescriptorType_t,
    descriptor: *mut cudnnBackendDescriptor_t,
) -> cudnnStatus_t;
/// Signature of `cudnnBackendDestroyDescriptor`.
pub type PFN_cudnnBackendDestroyDescriptor =
    unsafe extern "C" fn(descriptor: cudnnBackendDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnBackendInitialize`.
pub type PFN_cudnnBackendInitialize =
    unsafe extern "C" fn(descriptor: cudnnBackendDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnBackendFinalize`.
pub type PFN_cudnnBackendFinalize =
    unsafe extern "C" fn(descriptor: cudnnBackendDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnBackendSetAttribute`.
///
/// `array_of_elements` is an untyped read-only pointer; `attribute_type` and
/// `element_count` are the only description of what it points to.
pub type PFN_cudnnBackendSetAttribute = unsafe extern "C" fn(
    descriptor: cudnnBackendDescriptor_t,
    attribute_name: cudnnBackendAttributeName_t,
    attribute_type: cudnnBackendAttributeType_t,
    element_count: i64,
    array_of_elements: *const c_void,
) -> cudnnStatus_t;
/// Signature of `cudnnBackendGetAttribute`.
///
/// `element_count` and `array_of_elements` are writable (presumably
/// out-parameters); `requested_element_count` is passed by value.
pub type PFN_cudnnBackendGetAttribute = unsafe extern "C" fn(
    descriptor: cudnnBackendDescriptor_t,
    attribute_name: cudnnBackendAttributeName_t,
    attribute_type: cudnnBackendAttributeType_t,
    requested_element_count: i64,
    element_count: *mut i64,
    array_of_elements: *mut c_void,
) -> cudnnStatus_t;
/// Signature of `cudnnBackendExecute`; both the execution plan and the
/// variant pack are backend descriptors.
pub type PFN_cudnnBackendExecute = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    execution_plan: cudnnBackendDescriptor_t,
    variant_pack: cudnnBackendDescriptor_t,
) -> cudnnStatus_t;
1366
/// Signature of `cudnnGetErrorString`; returns a C string pointer for a
/// status code (ownership/lifetime of the string is not visible here —
/// presumably static, confirm before freeing or storing).
pub type PFN_cudnnGetErrorString =
    unsafe extern "C" fn(status: cudnnStatus_t) -> *const core::ffi::c_char;
1372
/// Signature of `cudnnSetTensorNdDescriptor`.
///
/// `dim_a` and `stride_a` are read-only `int` arrays (presumably `nb_dims`
/// elements each — confirm).
pub type PFN_cudnnSetTensorNdDescriptor = unsafe extern "C" fn(
    desc: cudnnTensorDescriptor_t,
    data_type: cudnnDataType_t,
    nb_dims: c_int,
    dim_a: *const c_int,
    stride_a: *const c_int,
) -> cudnnStatus_t;

/// Signature of `cudnnGetTensorNdDescriptor`.
///
/// All four trailing pointers are writable. `data_type` is a pointer to a
/// Rust enum that C writes through.
///
/// NOTE(review): if cuDNN stores a code that [`cudnnDataType_t`] does not
/// list, reading `*data_type` afterwards is undefined behaviour in Rust —
/// consider receiving into an `i32` at the call site.
pub type PFN_cudnnGetTensorNdDescriptor = unsafe extern "C" fn(
    desc: cudnnTensorDescriptor_t,
    nb_dims_requested: c_int,
    data_type: *mut cudnnDataType_t,
    nb_dims: *mut c_int,
    dim_a: *mut c_int,
    stride_a: *mut c_int,
) -> cudnnStatus_t;

/// Signature of `cudnnSetFilterNdDescriptor`; `filter_dim_a` is a read-only
/// `int` array (presumably `nb_dims` elements — confirm).
pub type PFN_cudnnSetFilterNdDescriptor = unsafe extern "C" fn(
    desc: cudnnFilterDescriptor_t,
    data_type: cudnnDataType_t,
    format: cudnnTensorFormat_t,
    nb_dims: c_int,
    filter_dim_a: *const c_int,
) -> cudnnStatus_t;

/// Signature of `cudnnSetConvolutionNdDescriptor`; the three arrays are
/// read-only (presumably `array_length` elements each — confirm).
pub type PFN_cudnnSetConvolutionNdDescriptor = unsafe extern "C" fn(
    desc: cudnnConvolutionDescriptor_t,
    array_length: c_int,
    pad_a: *const c_int,
    filter_stride_a: *const c_int,
    dilation_a: *const c_int,
    mode: cudnnConvolutionMode_t,
    compute_type: cudnnDataType_t,
) -> cudnnStatus_t;

/// Signature of `cudnnSetPoolingNdDescriptor`; the three arrays are read-only
/// (presumably `nb_dims` elements each — confirm).
pub type PFN_cudnnSetPoolingNdDescriptor = unsafe extern "C" fn(
    desc: cudnnPoolingDescriptor_t,
    mode: cudnnPoolingMode_t,
    nan_prop: cudnnNanPropagation_t,
    nb_dims: c_int,
    window_dim_a: *const c_int,
    padding_a: *const c_int,
    stride_a: *const c_int,
) -> cudnnStatus_t;
1424
/// Signature of `cudnnCreateCTCLossDescriptor`; `desc` points to a descriptor
/// slot (presumably an out-parameter).
pub type PFN_cudnnCreateCTCLossDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnCTCLossDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnDestroyCTCLossDescriptor`.
pub type PFN_cudnnDestroyCTCLossDescriptor =
    unsafe extern "C" fn(desc: cudnnCTCLossDescriptor_t) -> cudnnStatus_t;
/// Signature of `cudnnSetCTCLossDescriptor`.
pub type PFN_cudnnSetCTCLossDescriptor = unsafe extern "C" fn(
    desc: cudnnCTCLossDescriptor_t,
    compute_type: cudnnDataType_t,
) -> cudnnStatus_t;
1438
/// Signature of `cudnnGetCTCLossWorkspaceSize`.
///
/// `ctc_loss_algo` is a bare `c_int` rather than a dedicated enum;
/// `size_in_bytes` is a writable `usize` pointer (presumably an
/// out-parameter).
///
/// NOTE(review): whether `labels` / `label_lengths` / `input_lengths` are
/// host or device pointers is not visible here — confirm.
pub type PFN_cudnnGetCTCLossWorkspaceSize = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    probs_desc: cudnnTensorDescriptor_t,
    gradients_desc: cudnnTensorDescriptor_t,
    labels: *const c_int,
    label_lengths: *const c_int,
    input_lengths: *const c_int,
    ctc_loss_algo: c_int,
    ctc_loss_desc: cudnnCTCLossDescriptor_t,
    size_in_bytes: *mut usize,
) -> cudnnStatus_t;
1451
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnCTCLoss`.
///
/// Read-only inputs are `probs`, `labels`, `label_lengths` and
/// `input_lengths`; `costs`, `gradients` and `workspace` are writable
/// buffers. `workspace_size` is passed by value (presumably in bytes, as
/// obtained from the matching workspace-size query — confirm).
pub type PFN_cudnnCTCLoss = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    probs_desc: cudnnTensorDescriptor_t,
    probs: *const c_void,
    labels: *const c_int,
    label_lengths: *const c_int,
    input_lengths: *const c_int,
    costs: *mut c_void,
    gradients_desc: cudnnTensorDescriptor_t,
    gradients: *mut c_void,
    ctc_loss_algo: c_int,
    ctc_loss_desc: cudnnCTCLossDescriptor_t,
    workspace: *mut c_void,
    workspace_size: usize,
) -> cudnnStatus_t;
1468
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnRNNBackwardData_v8`.
///
/// Writable buffers are `dx`, `dhx`, `dcx`, `work_space` and `reserve_space`;
/// every other data pointer is `*const`. Each `*_space` buffer is paired with
/// a by-value `usize` size (presumably bytes — confirm against the cuDNN API
/// reference).
pub type PFN_cudnnRNNBackwardData_v8 = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    rnn_desc: cudnnRNNDescriptor_t,
    dev_seq_lengths: *const i32,
    y_desc: cudnnRNNDataDescriptor_t,
    y: *const c_void,
    dy: *const c_void,
    x_desc: cudnnRNNDataDescriptor_t,
    dx: *mut c_void,
    h_desc: cudnnTensorDescriptor_t,
    hx: *const c_void,
    dhy: *const c_void,
    dhx: *mut c_void,
    c_desc: cudnnTensorDescriptor_t,
    cx: *const c_void,
    dcy: *const c_void,
    dcx: *mut c_void,
    weight_space_size: usize,
    weight_space: *const c_void,
    work_space_size: usize,
    work_space: *mut c_void,
    reserve_space_size: usize,
    reserve_space: *mut c_void,
) -> cudnnStatus_t;
1496
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnRNNBackwardWeights_v8`.
///
/// `add_grad` is passed as a plain `c_int` rather than a Rust enum. Writable
/// buffers are `dweight_space`, `work_space` and `reserve_space`, each paired
/// with a by-value `usize` size.
pub type PFN_cudnnRNNBackwardWeights_v8 = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    rnn_desc: cudnnRNNDescriptor_t,
    add_grad: c_int,
    dev_seq_lengths: *const i32,
    x_desc: cudnnRNNDataDescriptor_t,
    x: *const c_void,
    h_desc: cudnnTensorDescriptor_t,
    hx: *const c_void,
    y_desc: cudnnRNNDataDescriptor_t,
    y: *const c_void,
    weight_space_size: usize,
    dweight_space: *mut c_void,
    work_space_size: usize,
    work_space: *mut c_void,
    reserve_space_size: usize,
    reserve_space: *mut c_void,
) -> cudnnStatus_t;
1516
/// Opaque spatial-transformer descriptor handle, represented as an untyped
/// raw pointer like the other descriptor aliases in this file.
pub type cudnnSpatialTransformerDescriptor_t = *mut c_void;

/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnCreateSpatialTransformerDescriptor`.
///
/// Takes a pointer to a descriptor handle, presumably written on success.
pub type PFN_cudnnCreateSpatialTransformerDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnSpatialTransformerDescriptor_t) -> cudnnStatus_t;
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnDestroySpatialTransformerDescriptor`.
pub type PFN_cudnnDestroySpatialTransformerDescriptor =
    unsafe extern "C" fn(desc: cudnnSpatialTransformerDescriptor_t) -> cudnnStatus_t;
1528
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnSetSpatialTransformerNdDescriptor`.
///
/// `sampler_type` is passed as a plain `c_int` rather than a Rust enum;
/// `dim_a` is a read-only `c_int` array (presumably `nb_dims` entries).
pub type PFN_cudnnSetSpatialTransformerNdDescriptor = unsafe extern "C" fn(
    desc: cudnnSpatialTransformerDescriptor_t,
    sampler_type: c_int,
    data_type: cudnnDataType_t,
    nb_dims: c_int,
    dim_a: *const c_int,
) -> cudnnStatus_t;
1537
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnSpatialTfGridGeneratorForward`.
///
/// Reads from `theta` (`*const`) and writes to `grid` (`*mut`).
pub type PFN_cudnnSpatialTfGridGeneratorForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    st_desc: cudnnSpatialTransformerDescriptor_t,
    theta: *const c_void,
    grid: *mut c_void,
) -> cudnnStatus_t;
1545
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnSpatialTfSamplerForward`.
///
/// `alpha` and `beta` are untyped read-only pointers (presumably scaling
/// factors whose pointee type depends on the tensor data type — confirm).
/// `y` is the only writable buffer.
pub type PFN_cudnnSpatialTfSamplerForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    st_desc: cudnnSpatialTransformerDescriptor_t,
    alpha: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    grid: *const c_void,
    beta: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
) -> cudnnStatus_t;
1558
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnSetConvolutionGroupCount`.
pub type PFN_cudnnSetConvolutionGroupCount = unsafe extern "C" fn(
    desc: cudnnConvolutionDescriptor_t,
    group_count: c_int,
) -> cudnnStatus_t;
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnGetConvolutionGroupCount`.
///
/// `group_count` is a `*mut c_int`, presumably the out-parameter.
pub type PFN_cudnnGetConvolutionGroupCount = unsafe extern "C" fn(
    desc: cudnnConvolutionDescriptor_t,
    group_count: *mut c_int,
) -> cudnnStatus_t;
1573
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnSetConvolutionMathType`.
pub type PFN_cudnnSetConvolutionMathType = unsafe extern "C" fn(
    desc: cudnnConvolutionDescriptor_t,
    math_type: cudnnMathType_t,
) -> cudnnStatus_t;
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnGetConvolutionMathType`.
///
/// `math_type` is a `*mut cudnnMathType_t`, presumably the out-parameter.
pub type PFN_cudnnGetConvolutionMathType = unsafe extern "C" fn(
    desc: cudnnConvolutionDescriptor_t,
    math_type: *mut cudnnMathType_t,
) -> cudnnStatus_t;
1584
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnSetConvolutionReorderType`.
pub type PFN_cudnnSetConvolutionReorderType = unsafe extern "C" fn(
    desc: cudnnConvolutionDescriptor_t,
    reorder_type: cudnnReorderType_t,
) -> cudnnStatus_t;
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnGetConvolutionReorderType`.
///
/// `reorder_type` is a `*mut cudnnReorderType_t`, presumably the out-parameter.
pub type PFN_cudnnGetConvolutionReorderType = unsafe extern "C" fn(
    desc: cudnnConvolutionDescriptor_t,
    reorder_type: *mut cudnnReorderType_t,
) -> cudnnStatus_t;
1595
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnReorderFilterAndBias`.
///
/// Reads `filter_data` / `bias_data` and writes `reordered_filter_data` /
/// `reordered_bias_data`. `reorder_bias` is a `c_int` (presumably a boolean
/// flag selecting whether the bias is processed — confirm).
pub type PFN_cudnnReorderFilterAndBias = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    filter_desc: cudnnFilterDescriptor_t,
    reorder_type: cudnnReorderType_t,
    filter_data: *const c_void,
    reordered_filter_data: *mut c_void,
    reorder_bias: c_int,
    bias_data: *const c_void,
    reordered_bias_data: *mut c_void,
) -> cudnnStatus_t;
1607
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnConvolutionBiasActivationForward`.
///
/// Read-only data inputs are `x`, `w`, `z` and `bias`, plus the untyped
/// `alpha1` / `alpha2` pointers; writable buffers are `workspace` (with its
/// by-value `workspace_size`) and the output `y`.
pub type PFN_cudnnConvolutionBiasActivationForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    alpha1: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    w_desc: cudnnFilterDescriptor_t,
    w: *const c_void,
    conv_desc: cudnnConvolutionDescriptor_t,
    algo: cudnnConvolutionFwdAlgo_t,
    workspace: *mut c_void,
    workspace_size: usize,
    alpha2: *const c_void,
    z_desc: cudnnTensorDescriptor_t,
    z: *const c_void,
    bias_desc: cudnnTensorDescriptor_t,
    bias: *const c_void,
    activation_desc: cudnnActivationDescriptor_t,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
) -> cudnnStatus_t;
1629
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnActivationBackward`.
///
/// Reads `y`, `dy` and `x` (each with its own tensor descriptor) and writes
/// `dx`; `alpha` and `beta` are untyped read-only pointers.
pub type PFN_cudnnActivationBackward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    activation_desc: cudnnActivationDescriptor_t,
    alpha: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *const c_void,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    beta: *const c_void,
    dx_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
) -> cudnnStatus_t;
1645
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnSetActivationDescriptorSwishBeta`; the beta value is a C `double`.
pub type PFN_cudnnSetActivationDescriptorSwishBeta = unsafe extern "C" fn(
    desc: cudnnActivationDescriptor_t,
    swish_beta: c_double,
) -> cudnnStatus_t;
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnGetActivationDescriptorSwishBeta`.
///
/// `swish_beta` is a `*mut c_double`, presumably the out-parameter.
pub type PFN_cudnnGetActivationDescriptorSwishBeta = unsafe extern "C" fn(
    desc: cudnnActivationDescriptor_t,
    swish_beta: *mut c_double,
) -> cudnnStatus_t;
1656
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnLRNCrossChannelBackward`.
///
/// `lrn_mode` is passed as a plain `c_int` rather than a Rust enum. Reads
/// `y`, `dy` and `x`; writes `dx`.
pub type PFN_cudnnLRNCrossChannelBackward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    norm_desc: cudnnLRNDescriptor_t,
    lrn_mode: c_int,
    alpha: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *const c_void,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    beta: *const c_void,
    dx_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
) -> cudnnStatus_t;
1673
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnDivisiveNormalizationForward`.
///
/// `mode` is passed as a plain `c_int`. `x` and `means` share the single
/// `x_desc` descriptor argument; `temp` and `temp2` are writable buffers
/// (presumably scratch space — confirm), and `y` is the writable output.
pub type PFN_cudnnDivisiveNormalizationForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    norm_desc: cudnnLRNDescriptor_t,
    mode: c_int,
    alpha: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    means: *const c_void,
    temp: *mut c_void,
    temp2: *mut c_void,
    beta: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
) -> cudnnStatus_t;
1689
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnDivisiveNormalizationBackward`.
///
/// Reads `x`, `means` and `dy`; writes `temp`, `temp2`, `dx` and `d_means`.
/// Only one descriptor (`d_xdmeans_desc`) accompanies the two gradient
/// outputs, so presumably both share that layout — confirm.
pub type PFN_cudnnDivisiveNormalizationBackward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    norm_desc: cudnnLRNDescriptor_t,
    mode: c_int,
    alpha: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    means: *const c_void,
    dy: *const c_void,
    temp: *mut c_void,
    temp2: *mut c_void,
    beta: *const c_void,
    d_xdmeans_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
    d_means: *mut c_void,
) -> cudnnStatus_t;
1707
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnGetReductionIndicesSize`.
///
/// `size_in_bytes` is a `*mut usize`, presumably the out-parameter.
pub type PFN_cudnnGetReductionIndicesSize = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    desc: cudnnReduceTensorDescriptor_t,
    a_desc: cudnnTensorDescriptor_t,
    c_desc: cudnnTensorDescriptor_t,
    size_in_bytes: *mut usize,
) -> cudnnStatus_t;
1716
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnSetTensor4dDescriptorEx`.
///
/// Takes four extents (`n`, `c`, `h`, `w`) followed by four explicit strides,
/// all as by-value `c_int`s; no tensor-format argument is involved.
pub type PFN_cudnnSetTensor4dDescriptorEx = unsafe extern "C" fn(
    desc: cudnnTensorDescriptor_t,
    data_type: cudnnDataType_t,
    n: c_int,
    c: c_int,
    h: c_int,
    w: c_int,
    n_stride: c_int,
    c_stride: c_int,
    h_stride: c_int,
    w_stride: c_int,
) -> cudnnStatus_t;
1732
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnGetTensor4dDescriptor`.
///
/// Every argument after `desc` is a `*mut` pointer, presumably an
/// out-parameter receiving the data type, the four extents and the four
/// strides respectively.
pub type PFN_cudnnGetTensor4dDescriptor = unsafe extern "C" fn(
    desc: cudnnTensorDescriptor_t,
    data_type: *mut cudnnDataType_t,
    n: *mut c_int,
    c: *mut c_int,
    h: *mut c_int,
    w: *mut c_int,
    n_stride: *mut c_int,
    c_stride: *mut c_int,
    h_stride: *mut c_int,
    w_stride: *mut c_int,
) -> cudnnStatus_t;
1746
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnGetFilter4dDescriptor`.
///
/// Every argument after `desc` is a `*mut` pointer, presumably an
/// out-parameter (data type, format, then the `k`/`c`/`h`/`w` extents).
pub type PFN_cudnnGetFilter4dDescriptor = unsafe extern "C" fn(
    desc: cudnnFilterDescriptor_t,
    data_type: *mut cudnnDataType_t,
    format: *mut cudnnTensorFormat_t,
    k: *mut c_int,
    c: *mut c_int,
    h: *mut c_int,
    w: *mut c_int,
) -> cudnnStatus_t;
1757
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnGetDropoutDescriptor`.
///
/// Note the descriptor comes before the handle. `dropout`, `states` and
/// `seed` are `*mut` pointers, presumably out-parameters; `states` is a
/// pointer to a pointer, so the callee can hand back a buffer address.
pub type PFN_cudnnGetDropoutDescriptor = unsafe extern "C" fn(
    desc: cudnnDropoutDescriptor_t,
    handle: cudnnHandle_t,
    dropout: *mut f32,
    states: *mut *mut c_void,
    seed: *mut u64,
) -> cudnnStatus_t;
1768
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnRestoreDropoutDescriptor`.
///
/// Note the descriptor comes before the handle. `states` is a writable buffer
/// paired with `state_size` (presumably in bytes — confirm).
pub type PFN_cudnnRestoreDropoutDescriptor = unsafe extern "C" fn(
    desc: cudnnDropoutDescriptor_t,
    handle: cudnnHandle_t,
    dropout: f32,
    states: *mut c_void,
    state_size: usize,
    seed: u64,
) -> cudnnStatus_t;
1778
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnGetConvolutionForwardAlgorithm_v7`.
///
/// `perf_results` points at caller-provided `cudnnConvolutionFwdAlgoPerf_t`
/// storage (presumably at least `requested_algo_count` entries) and
/// `returned_algo_count` presumably receives how many were filled — confirm.
pub type PFN_cudnnGetConvolutionForwardAlgorithm_v7 = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    src_desc: cudnnTensorDescriptor_t,
    filter_desc: cudnnFilterDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    dst_desc: cudnnTensorDescriptor_t,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionFwdAlgoPerf_t,
) -> cudnnStatus_t;
1794
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnFindConvolutionForwardAlgorithm`.
///
/// Has the same parameter list as [`PFN_cudnnGetConvolutionForwardAlgorithm_v7`].
pub type PFN_cudnnFindConvolutionForwardAlgorithm = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    src_desc: cudnnTensorDescriptor_t,
    filter_desc: cudnnFilterDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    dst_desc: cudnnTensorDescriptor_t,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionFwdAlgoPerf_t,
) -> cudnnStatus_t;
1806
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnFindConvolutionForwardAlgorithmEx`.
///
/// Unlike the non-`Ex` alias, this one also takes the data pointers (`src`,
/// `filter`, writable `dst`) and a caller-provided `workspace` with its
/// by-value `workspace_size`.
pub type PFN_cudnnFindConvolutionForwardAlgorithmEx = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    src_desc: cudnnTensorDescriptor_t,
    src: *const c_void,
    filter_desc: cudnnFilterDescriptor_t,
    filter: *const c_void,
    conv_desc: cudnnConvolutionDescriptor_t,
    dst_desc: cudnnTensorDescriptor_t,
    dst: *mut c_void,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionFwdAlgoPerf_t,
    workspace: *mut c_void,
    workspace_size: usize,
) -> cudnnStatus_t;
1823
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnGetConvolutionBackwardDataAlgorithm_v7`.
///
/// The filter descriptor comes first here; results go to caller-provided
/// `cudnnConvolutionBwdDataAlgoPerf_t` storage via `perf_results`.
pub type PFN_cudnnGetConvolutionBackwardDataAlgorithm_v7 = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    filter_desc: cudnnFilterDescriptor_t,
    diff_desc: cudnnTensorDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    grad_desc: cudnnTensorDescriptor_t,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionBwdDataAlgoPerf_t,
) -> cudnnStatus_t;
1835
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnFindConvolutionBackwardDataAlgorithm`.
///
/// Has the same parameter list as
/// [`PFN_cudnnGetConvolutionBackwardDataAlgorithm_v7`].
pub type PFN_cudnnFindConvolutionBackwardDataAlgorithm = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    filter_desc: cudnnFilterDescriptor_t,
    diff_desc: cudnnTensorDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    grad_desc: cudnnTensorDescriptor_t,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionBwdDataAlgoPerf_t,
) -> cudnnStatus_t;
1847
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnGetConvolutionBackwardFilterAlgorithm_v7`.
///
/// Here `grad_desc` is a *filter* descriptor (the gradient is with respect to
/// the filter); results go to `cudnnConvolutionBwdFilterAlgoPerf_t` storage.
pub type PFN_cudnnGetConvolutionBackwardFilterAlgorithm_v7 = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    src_desc: cudnnTensorDescriptor_t,
    diff_desc: cudnnTensorDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    grad_desc: cudnnFilterDescriptor_t,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionBwdFilterAlgoPerf_t,
) -> cudnnStatus_t;
1859
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnFindConvolutionBackwardFilterAlgorithm`.
///
/// Has the same parameter list as
/// [`PFN_cudnnGetConvolutionBackwardFilterAlgorithm_v7`].
pub type PFN_cudnnFindConvolutionBackwardFilterAlgorithm = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    src_desc: cudnnTensorDescriptor_t,
    diff_desc: cudnnTensorDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    grad_desc: cudnnFilterDescriptor_t,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionBwdFilterAlgoPerf_t,
) -> cudnnStatus_t;
1871
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize`.
///
/// `size_in_bytes` is a `*mut usize`, presumably the out-parameter.
pub type PFN_cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize =
    unsafe extern "C" fn(
        handle: cudnnHandle_t,
        mode: cudnnBatchNormMode_t,
        bn_ops: cudnnBatchNormOps_t,
        x_desc: cudnnTensorDescriptor_t,
        z_desc: cudnnTensorDescriptor_t,
        y_desc: cudnnTensorDescriptor_t,
        bn_scale_bias_mean_var_desc: cudnnTensorDescriptor_t,
        activation_desc: cudnnActivationDescriptor_t,
        size_in_bytes: *mut usize,
    ) -> cudnnStatus_t;
1889
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnGetBatchNormalizationBackwardExWorkspaceSize`.
///
/// `size_in_bytes` is a `*mut usize`, presumably the out-parameter.
pub type PFN_cudnnGetBatchNormalizationBackwardExWorkspaceSize =
    unsafe extern "C" fn(
        handle: cudnnHandle_t,
        mode: cudnnBatchNormMode_t,
        bn_ops: cudnnBatchNormOps_t,
        x_desc: cudnnTensorDescriptor_t,
        y_desc: cudnnTensorDescriptor_t,
        dy_desc: cudnnTensorDescriptor_t,
        dz_desc: cudnnTensorDescriptor_t,
        dx_desc: cudnnTensorDescriptor_t,
        d_bn_scale_bias_desc: cudnnTensorDescriptor_t,
        activation_desc: cudnnActivationDescriptor_t,
        size_in_bytes: *mut usize,
    ) -> cudnnStatus_t;
1905
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnGetBatchNormalizationTrainingExReserveSpaceSize`.
///
/// Note `activation_desc` precedes `x_desc` in this signature.
/// `size_in_bytes` is a `*mut usize`, presumably the out-parameter.
pub type PFN_cudnnGetBatchNormalizationTrainingExReserveSpaceSize =
    unsafe extern "C" fn(
        handle: cudnnHandle_t,
        mode: cudnnBatchNormMode_t,
        bn_ops: cudnnBatchNormOps_t,
        activation_desc: cudnnActivationDescriptor_t,
        x_desc: cudnnTensorDescriptor_t,
        size_in_bytes: *mut usize,
    ) -> cudnnStatus_t;
1916
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnBatchNormalizationForwardTrainingEx`.
///
/// Reads `x`, `z`, `bn_scale` and `bn_bias`; writable buffers are `y`, the
/// running and saved statistics, `workspace` and `reserve_space`.
/// `exponential_average_factor` and `epsilon` are C `double`s passed by
/// value. Buffer sizes presumably come from the matching `*Size` queries —
/// confirm against the cuDNN API reference.
pub type PFN_cudnnBatchNormalizationForwardTrainingEx = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnBatchNormMode_t,
    bn_ops: cudnnBatchNormOps_t,
    alpha: *const c_void,
    beta: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    z_desc: cudnnTensorDescriptor_t,
    z: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
    bn_scale_bias_mean_var_desc: cudnnTensorDescriptor_t,
    bn_scale: *const c_void,
    bn_bias: *const c_void,
    exponential_average_factor: c_double,
    result_running_mean: *mut c_void,
    result_running_variance: *mut c_void,
    epsilon: c_double,
    save_mean: *mut c_void,
    save_inv_variance: *mut c_void,
    activation_desc: cudnnActivationDescriptor_t,
    workspace: *mut c_void,
    workspace_size: usize,
    reserve_space: *mut c_void,
    reserve_space_size: usize,
) -> cudnnStatus_t;
1945
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnBatchNormalizationBackwardEx`.
///
/// Takes four untyped scaling pointers (`alpha`/`beta` for data and parameter
/// gradients). Reads `x`, `y`, `dy`, `bn_scale`, `bn_bias` and the saved
/// statistics; writes `dz`, `dx`, `d_bn_scale_result`, `d_bn_bias_result`,
/// `workspace` and `reserve_space`.
pub type PFN_cudnnBatchNormalizationBackwardEx = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnBatchNormMode_t,
    bn_ops: cudnnBatchNormOps_t,
    alpha_data_diff: *const c_void,
    beta_data_diff: *const c_void,
    alpha_param_diff: *const c_void,
    beta_param_diff: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *const c_void,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    dz_desc: cudnnTensorDescriptor_t,
    dz: *mut c_void,
    dx_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
    d_bn_scale_bias_desc: cudnnTensorDescriptor_t,
    bn_scale: *const c_void,
    bn_bias: *const c_void,
    d_bn_scale_result: *mut c_void,
    d_bn_bias_result: *mut c_void,
    epsilon: c_double,
    saved_mean: *const c_void,
    saved_inv_variance: *const c_void,
    activation_desc: cudnnActivationDescriptor_t,
    workspace: *mut c_void,
    workspace_size: usize,
    reserve_space: *mut c_void,
    reserve_space_size: usize,
) -> cudnnStatus_t;
1979
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnNormalizationForwardInference`.
///
/// All data pointers are read-only except the output `y`. `epsilon` is a C
/// `double` and `group_count` a `c_int`, both passed by value at the end of
/// the argument list.
pub type PFN_cudnnNormalizationForwardInference = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnNormMode_t,
    norm_ops: cudnnNormOps_t,
    algo: cudnnNormAlgo_t,
    alpha: *const c_void,
    beta: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    norm_scale_bias_desc: cudnnTensorDescriptor_t,
    norm_scale: *const c_void,
    norm_bias: *const c_void,
    norm_mean_var_desc: cudnnTensorDescriptor_t,
    estimated_mean: *const c_void,
    estimated_variance: *const c_void,
    z_desc: cudnnTensorDescriptor_t,
    z: *const c_void,
    activation_desc: cudnnActivationDescriptor_t,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
    epsilon: c_double,
    group_count: c_int,
) -> cudnnStatus_t;
2006
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnGetNormalizationForwardTrainingWorkspaceSize`.
///
/// `size_in_bytes` (`*mut usize`, presumably the out-parameter) is followed
/// by the by-value `group_count`.
pub type PFN_cudnnGetNormalizationForwardTrainingWorkspaceSize =
    unsafe extern "C" fn(
        handle: cudnnHandle_t,
        mode: cudnnNormMode_t,
        norm_ops: cudnnNormOps_t,
        algo: cudnnNormAlgo_t,
        x_desc: cudnnTensorDescriptor_t,
        z_desc: cudnnTensorDescriptor_t,
        y_desc: cudnnTensorDescriptor_t,
        norm_scale_bias_desc: cudnnTensorDescriptor_t,
        activation_desc: cudnnActivationDescriptor_t,
        norm_mean_var_desc: cudnnTensorDescriptor_t,
        size_in_bytes: *mut usize,
        group_count: c_int,
    ) -> cudnnStatus_t;
2023
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnGetNormalizationBackwardWorkspaceSize`.
///
/// `size_in_bytes` (`*mut usize`, presumably the out-parameter) is followed
/// by the by-value `group_count`.
pub type PFN_cudnnGetNormalizationBackwardWorkspaceSize = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnNormMode_t,
    norm_ops: cudnnNormOps_t,
    algo: cudnnNormAlgo_t,
    x_desc: cudnnTensorDescriptor_t,
    y_desc: cudnnTensorDescriptor_t,
    dy_desc: cudnnTensorDescriptor_t,
    dz_desc: cudnnTensorDescriptor_t,
    dx_desc: cudnnTensorDescriptor_t,
    d_norm_scale_bias_desc: cudnnTensorDescriptor_t,
    activation_desc: cudnnActivationDescriptor_t,
    norm_mean_var_desc: cudnnTensorDescriptor_t,
    size_in_bytes: *mut usize,
    group_count: c_int,
) -> cudnnStatus_t;
2041
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnGetNormalizationTrainingReserveSpaceSize`.
///
/// Note `activation_desc` precedes `x_desc`. `size_in_bytes` (`*mut usize`,
/// presumably the out-parameter) is followed by the by-value `group_count`.
pub type PFN_cudnnGetNormalizationTrainingReserveSpaceSize =
    unsafe extern "C" fn(
        handle: cudnnHandle_t,
        mode: cudnnNormMode_t,
        norm_ops: cudnnNormOps_t,
        algo: cudnnNormAlgo_t,
        activation_desc: cudnnActivationDescriptor_t,
        x_desc: cudnnTensorDescriptor_t,
        size_in_bytes: *mut usize,
        group_count: c_int,
    ) -> cudnnStatus_t;
2054
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnNormalizationForwardTraining`.
///
/// Reads `x`, `norm_scale`, `norm_bias` and `z`; writable buffers are the
/// running and saved statistics, `y`, `workspace` and `reserve_space`.
/// Argument order differs from the batch-norm `Ex` alias:
/// `exponential_average_factor` precedes `norm_mean_var_desc`, and the
/// `z`/`y` pairs come after `activation_desc`.
pub type PFN_cudnnNormalizationForwardTraining = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnNormMode_t,
    norm_ops: cudnnNormOps_t,
    algo: cudnnNormAlgo_t,
    alpha: *const c_void,
    beta: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    norm_scale_bias_desc: cudnnTensorDescriptor_t,
    norm_scale: *const c_void,
    norm_bias: *const c_void,
    exponential_average_factor: c_double,
    norm_mean_var_desc: cudnnTensorDescriptor_t,
    result_running_mean: *mut c_void,
    result_running_variance: *mut c_void,
    epsilon: c_double,
    save_mean: *mut c_void,
    save_inv_variance: *mut c_void,
    activation_desc: cudnnActivationDescriptor_t,
    z_desc: cudnnTensorDescriptor_t,
    z: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
    workspace: *mut c_void,
    workspace_size: usize,
    reserve_space: *mut c_void,
    reserve_space_size: usize,
    group_count: c_int,
) -> cudnnStatus_t;
2086
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnNormalizationBackward`.
///
/// Takes four untyped scaling pointers (`alpha`/`beta` for data and parameter
/// gradients). Reads `x`, `y`, `dy`, `norm_scale`, `norm_bias` and the saved
/// statistics; writes `dz`, `dx`, `d_norm_scale`, `d_norm_bias`, `workspace`
/// and `reserve_space`. `group_count` is the final by-value argument.
pub type PFN_cudnnNormalizationBackward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnNormMode_t,
    norm_ops: cudnnNormOps_t,
    algo: cudnnNormAlgo_t,
    alpha_data_diff: *const c_void,
    beta_data_diff: *const c_void,
    alpha_param_diff: *const c_void,
    beta_param_diff: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *const c_void,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    dz_desc: cudnnTensorDescriptor_t,
    dz: *mut c_void,
    dx_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
    d_norm_scale_bias_desc: cudnnTensorDescriptor_t,
    norm_scale: *const c_void,
    norm_bias: *const c_void,
    d_norm_scale: *mut c_void,
    d_norm_bias: *mut c_void,
    epsilon: c_double,
    norm_mean_var_desc: cudnnTensorDescriptor_t,
    saved_mean: *const c_void,
    saved_inv_variance: *const c_void,
    activation_desc: cudnnActivationDescriptor_t,
    workspace: *mut c_void,
    workspace_size: usize,
    reserve_space: *mut c_void,
    reserve_space_size: usize,
    group_count: c_int,
) -> cudnnStatus_t;
2123
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnSetRNNDescriptor_v8`.
///
/// Every argument is passed by value. `bias_mode` is a plain `c_int` rather
/// than a Rust enum, and `aux_flags` is a `u32` (presumably a bit mask —
/// confirm against the cuDNN API reference).
pub type PFN_cudnnSetRNNDescriptor_v8 = unsafe extern "C" fn(
    rnn_desc: cudnnRNNDescriptor_t,
    algo: cudnnRNNAlgo_t,
    cell_mode: cudnnRNNMode_t,
    bias_mode: c_int,
    dir_mode: cudnnDirectionMode_t,
    input_mode: cudnnRNNInputMode_t,
    data_type: cudnnDataType_t,
    math_prec: cudnnDataType_t,
    math_type: cudnnMathType_t,
    input_size: i32,
    hidden_size: i32,
    proj_size: i32,
    num_layers: i32,
    dropout_desc: cudnnDropoutDescriptor_t,
    aux_flags: u32,
) -> cudnnStatus_t;
2146
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnBuildRNNDynamic`.
pub type PFN_cudnnBuildRNNDynamic = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    rnn_desc: cudnnRNNDescriptor_t,
    mini_batch: c_int,
) -> cudnnStatus_t;
2153
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnGetRNNTempSpaceSizes`.
///
/// `fwd_mode` is a plain `c_int` rather than a Rust enum. `work_space_size`
/// and `reserve_space_size` are `*mut usize`, presumably out-parameters.
pub type PFN_cudnnGetRNNTempSpaceSizes = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    rnn_desc: cudnnRNNDescriptor_t,
    fwd_mode: c_int,
    x_desc: cudnnRNNDataDescriptor_t,
    work_space_size: *mut usize,
    reserve_space_size: *mut usize,
) -> cudnnStatus_t;
2163
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnGetRNNWeightSpaceSize`.
///
/// `weight_space_size` is a `*mut usize`, presumably the out-parameter.
pub type PFN_cudnnGetRNNWeightSpaceSize = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    rnn_desc: cudnnRNNDescriptor_t,
    weight_space_size: *mut usize,
) -> cudnnStatus_t;
2170
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnGetRNNWeightParams`.
///
/// `m_addr` and `b_addr` are pointers to pointers, so the callee can hand
/// back addresses (presumably locations inside `weight_space` — confirm).
/// `m_desc` / `b_desc` are passed by value as already-created descriptors.
pub type PFN_cudnnGetRNNWeightParams = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    rnn_desc: cudnnRNNDescriptor_t,
    pseudo_layer: i32,
    weight_space_size: usize,
    weight_space: *const c_void,
    lin_layer_id: i32,
    m_desc: cudnnTensorDescriptor_t,
    m_addr: *mut *mut c_void,
    b_desc: cudnnTensorDescriptor_t,
    b_addr: *mut *mut c_void,
) -> cudnnStatus_t;
2184
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnCreateAttnDescriptor`.
///
/// Takes a pointer to a descriptor handle, presumably written on success.
pub type PFN_cudnnCreateAttnDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnAttnDescriptor_t) -> cudnnStatus_t;
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnDestroyAttnDescriptor`.
pub type PFN_cudnnDestroyAttnDescriptor =
    unsafe extern "C" fn(desc: cudnnAttnDescriptor_t) -> cudnnStatus_t;
2195
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnSetAttnDescriptor`.
///
/// Every argument is passed by value. `attn_mode` is a `u32` (presumably a
/// bit mask — confirm), `sm_scaler` is a C `double`, and all size / length
/// limits are `i32`.
pub type PFN_cudnnSetAttnDescriptor = unsafe extern "C" fn(
    desc: cudnnAttnDescriptor_t,
    attn_mode: u32,
    n_heads: i32,
    sm_scaler: c_double,
    data_type: cudnnDataType_t,
    compute_prec: cudnnDataType_t,
    math_type: cudnnMathType_t,
    attn_dropout_desc: cudnnDropoutDescriptor_t,
    post_dropout_desc: cudnnDropoutDescriptor_t,
    q_size: i32,
    k_size: i32,
    v_size: i32,
    q_proj_size: i32,
    k_proj_size: i32,
    v_proj_size: i32,
    o_proj_size: i32,
    qo_max_seq_length: i32,
    kv_max_seq_length: i32,
    max_batch_size: i32,
    max_beam_size: i32,
) -> cudnnStatus_t;
2219
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnGetMultiHeadAttnBuffers`.
///
/// The three `*mut usize` arguments are presumably out-parameters receiving
/// the weight, workspace and reserve-space sizes.
pub type PFN_cudnnGetMultiHeadAttnBuffers = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    attn_desc: cudnnAttnDescriptor_t,
    weight_size_in_bytes: *mut usize,
    work_space_size_in_bytes: *mut usize,
    reserve_space_size_in_bytes: *mut usize,
) -> cudnnStatus_t;
2228
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnGetMultiHeadAttnWeights`.
///
/// `w_kind` is a plain `c_int` rather than a Rust enum. `w_addr` is a pointer
/// to a pointer, so the callee can hand back an address (presumably a
/// location inside `weights` — confirm).
pub type PFN_cudnnGetMultiHeadAttnWeights = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    attn_desc: cudnnAttnDescriptor_t,
    w_kind: c_int,
    weight_size_in_bytes: usize,
    weights: *const c_void,
    w_desc: cudnnTensorDescriptor_t,
    w_addr: *mut *mut c_void,
) -> cudnnStatus_t;
2239
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnMultiHeadAttnForward`.
///
/// Reads the window-index and sequence-length `i32` arrays, `queries`,
/// `residuals`, `keys`, `values` and `weights`; writes `out`, `work_space`
/// and `reserve_space`. Whether any pointer may be null (e.g. for inference)
/// is not visible here — confirm against the cuDNN API reference.
pub type PFN_cudnnMultiHeadAttnForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    attn_desc: cudnnAttnDescriptor_t,
    curr_idx: i32,
    lo_win_idx: *const i32,
    hi_win_idx: *const i32,
    dev_seq_lengths_qo: *const i32,
    dev_seq_lengths_kv: *const i32,
    q_desc: cudnnSeqDataDescriptor_t,
    queries: *const c_void,
    residuals: *const c_void,
    k_desc: cudnnSeqDataDescriptor_t,
    keys: *const c_void,
    v_desc: cudnnSeqDataDescriptor_t,
    values: *const c_void,
    o_desc: cudnnSeqDataDescriptor_t,
    out: *mut c_void,
    weight_size_in_bytes: usize,
    weights: *const c_void,
    work_space_size_in_bytes: usize,
    work_space: *mut c_void,
    reserve_space_size_in_bytes: usize,
    reserve_space: *mut c_void,
) -> cudnnStatus_t;
2265
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnMultiHeadAttnBackwardData`.
///
/// Reads `dout`, `queries`, `keys`, `values` and `weights`; writes the
/// gradients `dqueries`, `dkeys`, `dvalues` plus `work_space` and
/// `reserve_space`. Each gradient shares a descriptor argument with the
/// forward tensor that follows it.
pub type PFN_cudnnMultiHeadAttnBackwardData = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    attn_desc: cudnnAttnDescriptor_t,
    lo_win_idx: *const i32,
    hi_win_idx: *const i32,
    dev_seq_lengths_dqdo: *const i32,
    dev_seq_lengths_dkdv: *const i32,
    do_desc: cudnnSeqDataDescriptor_t,
    dout: *const c_void,
    dq_desc: cudnnSeqDataDescriptor_t,
    dqueries: *mut c_void,
    queries: *const c_void,
    dk_desc: cudnnSeqDataDescriptor_t,
    dkeys: *mut c_void,
    keys: *const c_void,
    dv_desc: cudnnSeqDataDescriptor_t,
    dvalues: *mut c_void,
    values: *const c_void,
    weight_size_in_bytes: usize,
    weights: *const c_void,
    work_space_size_in_bytes: usize,
    work_space: *mut c_void,
    reserve_space_size_in_bytes: usize,
    reserve_space: *mut c_void,
) -> cudnnStatus_t;
2292
/// C-ABI function-pointer type named for the cuDNN symbol
/// `cudnnMultiHeadAttnBackwardWeights`.
///
/// `add_grad` is a plain `c_int` rather than a Rust enum. Reads `queries`,
/// `keys`, `values`, `dout` and `weights`; writes `dweights`, `work_space`
/// and `reserve_space`.
pub type PFN_cudnnMultiHeadAttnBackwardWeights = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    attn_desc: cudnnAttnDescriptor_t,
    add_grad: c_int,
    q_desc: cudnnSeqDataDescriptor_t,
    queries: *const c_void,
    k_desc: cudnnSeqDataDescriptor_t,
    keys: *const c_void,
    v_desc: cudnnSeqDataDescriptor_t,
    values: *const c_void,
    do_desc: cudnnSeqDataDescriptor_t,
    dout: *const c_void,
    weight_size_in_bytes: usize,
    weights: *const c_void,
    dweights: *mut c_void,
    work_space_size_in_bytes: usize,
    work_space: *mut c_void,
    reserve_space_size_in_bytes: usize,
    reserve_space: *mut c_void,
) -> cudnnStatus_t;
2314
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnCreateSeqDataDescriptor`.
///
/// Takes a pointer to a descriptor handle, presumably written on success.
pub type PFN_cudnnCreateSeqDataDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnSeqDataDescriptor_t) -> cudnnStatus_t;
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnDestroySeqDataDescriptor`.
pub type PFN_cudnnDestroySeqDataDescriptor =
    unsafe extern "C" fn(desc: cudnnSeqDataDescriptor_t) -> cudnnStatus_t;
/// C-ABI function-pointer type named for the cuDNN symbol `cudnnSetSeqDataDescriptor`.
///
/// `axes` is passed as a `c_int` array rather than an array of a Rust enum.
/// `seq_length_array` is a read-only `c_int` array whose length is given
/// separately by `seq_length_array_size`; `padding_fill` is an untyped
/// read-only pointer (nullability not visible here — confirm).
pub type PFN_cudnnSetSeqDataDescriptor = unsafe extern "C" fn(
    desc: cudnnSeqDataDescriptor_t,
    data_type: cudnnDataType_t,
    nb_dims: c_int,
    dim_a: *const c_int,
    axes: *const c_int,
    seq_length_array_size: usize,
    seq_length_array: *const c_int,
    padding_fill: *const c_void,
) -> cudnnStatus_t;
2334
/// Returns the cuDNN shared-library file names to try, in order.
///
/// The list is chosen at compile time per target OS: versioned `.so` names
/// followed by the unversioned one on Linux, two versioned DLL names on
/// Windows, and an empty list on every other platform.
fn cudnn_candidates() -> Vec<String> {
    #[cfg(target_os = "linux")]
    const NAMES: &[&str] = &["libcudnn.so.9", "libcudnn.so.8", "libcudnn.so"];
    #[cfg(target_os = "windows")]
    const NAMES: &[&str] = &["cudnn64_9.dll", "cudnn64_8.dll"];
    #[cfg(not(any(target_os = "linux", target_os = "windows")))]
    const NAMES: &[&str] = &[];

    NAMES.iter().map(|name| (*name).to_string()).collect()
}
2358
2359fn detect_cuda_major() -> Option<u32> {
2369 if let Ok(p) = std::env::var("CUDA_PATH") {
2370 if let Some(n) = parse_cuda_major_from_path(&p) {
2371 return Some(n);
2372 }
2373 }
2374 let mut best: Option<u32> = None;
2376 for (k, _) in std::env::vars() {
2377 if let Some(rest) = k.strip_prefix("CUDA_PATH_V") {
2378 if let Some((maj, _)) = rest.split_once('_') {
2380 if let Ok(n) = maj.parse::<u32>() {
2381 best = Some(best.map_or(n, |b| b.max(n)));
2382 }
2383 }
2384 }
2385 }
2386 best
2387}
2388
/// Extracts a CUDA major version from a toolkit path such as
/// `C:\...\CUDA\v12.4` or `/opt/cuda/v11.8/bin`.
///
/// The path is split on both `/` and `\`. The first component of the form
/// `v<digits>.<anything>` (leading `v` or `V`) wins and its `<digits>` part is
/// returned. Components that start with `v` but do not match that shape
/// (for example `vendor` or `var`) are skipped rather than aborting the scan.
///
/// Returns `None` when no component matches.
fn parse_cuda_major_from_path(p: &str) -> Option<u32> {
    p.split(['/', '\\']).find_map(|component| {
        // Inside the closure `?` only rejects this component; the search then
        // moves on to the next one.
        let rest = component.strip_prefix(['v', 'V'])?;
        let (major, _minor) = rest.split_once('.')?;
        major.parse::<u32>().ok()
    })
}
2406
2407fn cudnn_extra_search_dirs() -> Vec<PathBuf> {
2408 let mut out = Vec::new();
2409
2410 if let Ok(p) = std::env::var("CUDNN_PATH") {
2411 let base = PathBuf::from(p);
2412 if cfg!(target_os = "windows") {
2413 out.push(base.join("bin"));
2414 } else {
2415 out.push(base.join("lib"));
2416 out.push(base.join("lib64"));
2417 }
2418 }
2419
2420 if cfg!(target_os = "windows") {
2421 let target_major = detect_cuda_major();
2429 if let Ok(pf) = std::env::var("ProgramFiles") {
2430 let cudnn_root = PathBuf::from(pf).join("NVIDIA").join("CUDNN");
2431 if let Ok(read_dir) = std::fs::read_dir(&cudnn_root) {
2432 for entry in read_dir.flatten() {
2433 let p = entry.path();
2434 if !p.is_dir() {
2435 continue;
2436 }
2437 let bin = p.join("bin");
2438 let mut numbered: Vec<(u32, PathBuf)> = Vec::new();
2439 let mut unnumbered: Vec<PathBuf> = Vec::new();
2440 if let Ok(sub) = std::fs::read_dir(&bin) {
2441 for s in sub.flatten() {
2442 let sp = s.path();
2443 if !sp.is_dir() {
2444 continue;
2445 }
2446 match sp
2447 .file_name()
2448 .and_then(|n| n.to_str())
2449 .and_then(|s| s.parse::<u32>().ok())
2450 {
2451 Some(n) => numbered.push((n, sp)),
2452 None => unnumbered.push(sp),
2453 }
2454 }
2455 }
2456 if let Some(target) = target_major {
2458 if let Some(pos) = numbered.iter().position(|(n, _)| *n == target) {
2459 let (_, matched) = numbered.swap_remove(pos);
2460 out.push(matched);
2461 } else {
2462 numbered.sort_by_key(|b| std::cmp::Reverse(b.0));
2464 if let Some(pos) = numbered.iter().position(|(n, _)| *n <= target) {
2465 let (_, matched) = numbered.remove(pos);
2466 out.push(matched);
2467 } else if let Some((_, p)) = numbered.into_iter().next() {
2468 out.push(p);
2469 }
2470 }
2471 } else {
2472 numbered.sort_by_key(|b| std::cmp::Reverse(b.0));
2476 for (_, p) in numbered {
2477 out.push(p);
2478 }
2479 }
2480 out.extend(unnumbered);
2483 }
2484 }
2485 }
2486 }
2487
2488 out
2489}
2490
/// Prepends the given directories to the process `PATH` (Windows only), at
/// most once per process.
///
/// Directories already present in `PATH`, and duplicates within `extra_dirs`,
/// are not added again. The existing `PATH` value is kept byte-for-byte and
/// only gains a prefix. The work is guarded by a `OnceLock`, so only the
/// first call has any effect; the `extra_dirs` of later calls are ignored.
///
/// This is best-effort: if the new entries cannot be encoded as a `PATH`
/// fragment, the environment is left untouched.
#[cfg(target_os = "windows")]
fn ensure_cudnn_on_path(extra_dirs: &[PathBuf]) {
    use std::sync::OnceLock;
    static DONE: OnceLock<()> = OnceLock::new();
    DONE.get_or_init(|| {
        // `var_os` rather than `var`: with `var`, a PATH that is not valid
        // Unicode reads as an error, which would be treated as "empty" and the
        // whole PATH would then be overwritten by the prefix alone.
        let existing = std::env::var_os("PATH").unwrap_or_default();
        let current: Vec<PathBuf> = std::env::split_paths(&existing).collect();

        let mut additions: Vec<&PathBuf> = Vec::new();
        for dir in extra_dirs {
            if !current.contains(dir) && !additions.contains(&dir) {
                additions.push(dir);
            }
        }
        if additions.is_empty() {
            return;
        }

        // `join_paths` inserts the separator and quoting; it fails only for
        // entries that cannot be represented in PATH, in which case we give up.
        let Ok(mut new_path) = std::env::join_paths(additions) else {
            return;
        };
        if !existing.is_empty() {
            new_path.push(";");
            new_path.push(&existing);
        }

        // SAFETY: NOTE(review) — mutating the environment is only sound while
        // no other thread reads or writes it concurrently. This function cannot
        // check that; callers are assumed to load cuDNN before spawning threads
        // that touch the environment. The `OnceLock` limits this to one write.
        unsafe {
            std::env::set_var("PATH", new_path);
        }
    });
}
2531
/// Non-Windows counterpart of `ensure_cudnn_on_path`: the body is empty, so
/// the directories are accepted and ignored and the environment is untouched.
#[cfg(not(target_os = "windows"))]
fn ensure_cudnn_on_path(_extra_dirs: &[PathBuf]) {}
2534
2535fn open_cudnn() -> Result<Library, LoaderError> {
2537 let candidates: Vec<&'static str> = cudnn_candidates()
2538 .into_iter()
2539 .map(|s| Box::leak(s.into_boxed_str()) as &'static str)
2540 .collect();
2541 let candidates_leaked: &'static [&'static str] = Box::leak(candidates.into_boxed_slice());
2542
2543 let extra = cudnn_extra_search_dirs();
2546 ensure_cudnn_on_path(&extra);
2547
2548 if let Ok(lib) = Library::open("cudnn", candidates_leaked) {
2550 return Ok(lib);
2551 }
2552
2553 for dir in &extra {
2555 for candidate in candidates_leaked {
2556 let full = dir.join(candidate);
2557 if let Ok(lib) = Library::open_at("cudnn", &full) {
2558 return Ok(lib);
2559 }
2560 }
2561 }
2562
2563 Err(LoaderError::library_not_found_with_search(
2564 "cudnn",
2565 candidates_leaked,
2566 extra.len(),
2567 ))
2568}
2569
/// Generates the `Cudnn` symbol table from a list of
/// `accessor as "symbol": PfnType;` entries.
///
/// For each entry the macro emits a `OnceLock` cache field and a public
/// accessor method of the same name that resolves the symbol from the loaded
/// library on first use. It also emits a manual `Debug` impl (printing only
/// the `lib` field) and a private `empty` constructor.
macro_rules! cudnn_fns {
    ($($name:ident as $sym:literal : $pfn:ty);* $(;)?) => {
        pub struct Cudnn {
            lib: Library,
            // One lazily filled slot per declared entry point.
            $($name: OnceLock<$pfn>,)*
        }
        impl core::fmt::Debug for Cudnn {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                // Only `lib` is shown; the per-symbol caches are elided.
                f.debug_struct("Cudnn").field("lib", &self.lib).finish_non_exhaustive()
            }
        }
        impl Cudnn {
            $(
                #[doc = concat!("Resolve `", $sym, "`.")]
                pub fn $name(&self) -> Result<$pfn, LoaderError> {
                    // Fast path: symbol already resolved and cached.
                    if let Some(&p) = self.$name.get() { return Ok(p); }
                    // A lookup failure is propagated to the caller and nothing is cached.
                    let raw: *mut () = unsafe { self.lib.raw_symbol($sym)? };
                    // Reinterprets the raw address as the typed function pointer.
                    // NOTE(review): sound only if the exported symbol really has the
                    // `$pfn` signature and `$pfn` is pointer-sized — neither is checked here.
                    let p: $pfn = unsafe { core::mem::transmute_copy::<*mut (), $pfn>(&raw) };
                    // `set` fails if the slot was filled in the meantime; that result is
                    // deliberately ignored and the freshly resolved pointer is returned.
                    let _ = self.$name.set(p);
                    Ok(p)
                }
            )*
            // Wraps an opened library with every symbol slot still unresolved.
            fn empty(lib: Library) -> Self {
                Self { lib, $($name: OnceLock::new(),)* }
            }
        }
    };
}
2599
// Symbol table for the `Cudnn` loader. Each entry has the form
// `accessor as "exported symbol": function-pointer type;`.
cudnn_fns! {
    // Library handle, stream binding, version and error-string queries.
    cudnn_create as "cudnnCreate": PFN_cudnnCreate;
    cudnn_destroy as "cudnnDestroy": PFN_cudnnDestroy;
    cudnn_set_stream as "cudnnSetStream": PFN_cudnnSetStream;
    cudnn_get_version as "cudnnGetVersion": PFN_cudnnGetVersion;
    cudnn_get_error_string as "cudnnGetErrorString": PFN_cudnnGetErrorString;
    // Tensor descriptors (4-D).
    cudnn_create_tensor_descriptor as "cudnnCreateTensorDescriptor": PFN_cudnnCreateTensorDescriptor;
    cudnn_destroy_tensor_descriptor as "cudnnDestroyTensorDescriptor": PFN_cudnnDestroyTensorDescriptor;
    cudnn_set_tensor_4d_descriptor as "cudnnSetTensor4dDescriptor": PFN_cudnnSetTensor4dDescriptor;
    // Activation descriptors and the forward activation call.
    cudnn_create_activation_descriptor as "cudnnCreateActivationDescriptor": PFN_cudnnCreateActivationDescriptor;
    cudnn_destroy_activation_descriptor as "cudnnDestroyActivationDescriptor": PFN_cudnnDestroyActivationDescriptor;
    cudnn_set_activation_descriptor as "cudnnSetActivationDescriptor": PFN_cudnnSetActivationDescriptor;
    cudnn_activation_forward as "cudnnActivationForward": PFN_cudnnActivationForward;
    // Filter and convolution descriptors, output-dim / workspace queries,
    // and the forward / backward convolution calls.
    cudnn_create_filter_descriptor as "cudnnCreateFilterDescriptor": PFN_cudnnCreateFilterDescriptor;
    cudnn_destroy_filter_descriptor as "cudnnDestroyFilterDescriptor": PFN_cudnnDestroyFilterDescriptor;
    cudnn_set_filter_4d_descriptor as "cudnnSetFilter4dDescriptor": PFN_cudnnSetFilter4dDescriptor;
    cudnn_create_convolution_descriptor as "cudnnCreateConvolutionDescriptor": PFN_cudnnCreateConvolutionDescriptor;
    cudnn_destroy_convolution_descriptor as "cudnnDestroyConvolutionDescriptor": PFN_cudnnDestroyConvolutionDescriptor;
    cudnn_set_convolution_2d_descriptor as "cudnnSetConvolution2dDescriptor": PFN_cudnnSetConvolution2dDescriptor;
    cudnn_get_convolution_2d_forward_output_dim as "cudnnGetConvolution2dForwardOutputDim": PFN_cudnnGetConvolution2dForwardOutputDim;
    cudnn_get_convolution_forward_workspace_size as "cudnnGetConvolutionForwardWorkspaceSize": PFN_cudnnGetConvolutionForwardWorkspaceSize;
    cudnn_convolution_forward as "cudnnConvolutionForward": PFN_cudnnConvolutionForward;
    cudnn_convolution_backward_data as "cudnnConvolutionBackwardData": PFN_cudnnConvolutionBackwardData;
    cudnn_convolution_backward_filter as "cudnnConvolutionBackwardFilter": PFN_cudnnConvolutionBackwardFilter;
    cudnn_convolution_backward_bias as "cudnnConvolutionBackwardBias": PFN_cudnnConvolutionBackwardBias;
    cudnn_get_convolution_backward_data_workspace_size as "cudnnGetConvolutionBackwardDataWorkspaceSize": PFN_cudnnGetConvolutionBackwardDataWorkspaceSize;
    cudnn_get_convolution_backward_filter_workspace_size as "cudnnGetConvolutionBackwardFilterWorkspaceSize": PFN_cudnnGetConvolutionBackwardFilterWorkspaceSize;
    // Pooling.
    cudnn_create_pooling_descriptor as "cudnnCreatePoolingDescriptor": PFN_cudnnCreatePoolingDescriptor;
    cudnn_destroy_pooling_descriptor as "cudnnDestroyPoolingDescriptor": PFN_cudnnDestroyPoolingDescriptor;
    cudnn_set_pooling_2d_descriptor as "cudnnSetPooling2dDescriptor": PFN_cudnnSetPooling2dDescriptor;
    cudnn_pooling_forward as "cudnnPoolingForward": PFN_cudnnPoolingForward;
    cudnn_pooling_backward as "cudnnPoolingBackward": PFN_cudnnPoolingBackward;
    // Softmax.
    cudnn_softmax_forward as "cudnnSoftmaxForward": PFN_cudnnSoftmaxForward;
    cudnn_softmax_backward as "cudnnSoftmaxBackward": PFN_cudnnSoftmaxBackward;
    // Batch normalization.
    cudnn_batch_normalization_forward_inference as "cudnnBatchNormalizationForwardInference": PFN_cudnnBatchNormalizationForwardInference;
    cudnn_batch_normalization_forward_training as "cudnnBatchNormalizationForwardTraining": PFN_cudnnBatchNormalizationForwardTraining;
    cudnn_batch_normalization_backward as "cudnnBatchNormalizationBackward": PFN_cudnnBatchNormalizationBackward;
    // Op-tensor and reduce-tensor descriptors, plus the basic tensor
    // utilities (add / transform / scale / set).
    cudnn_create_op_tensor_descriptor as "cudnnCreateOpTensorDescriptor": PFN_cudnnCreateOpTensorDescriptor;
    cudnn_destroy_op_tensor_descriptor as "cudnnDestroyOpTensorDescriptor": PFN_cudnnDestroyOpTensorDescriptor;
    cudnn_set_op_tensor_descriptor as "cudnnSetOpTensorDescriptor": PFN_cudnnSetOpTensorDescriptor;
    cudnn_op_tensor as "cudnnOpTensor": PFN_cudnnOpTensor;
    cudnn_create_reduce_tensor_descriptor as "cudnnCreateReduceTensorDescriptor": PFN_cudnnCreateReduceTensorDescriptor;
    cudnn_destroy_reduce_tensor_descriptor as "cudnnDestroyReduceTensorDescriptor": PFN_cudnnDestroyReduceTensorDescriptor;
    cudnn_set_reduce_tensor_descriptor as "cudnnSetReduceTensorDescriptor": PFN_cudnnSetReduceTensorDescriptor;
    cudnn_get_reduction_workspace_size as "cudnnGetReductionWorkspaceSize": PFN_cudnnGetReductionWorkspaceSize;
    cudnn_reduce_tensor as "cudnnReduceTensor": PFN_cudnnReduceTensor;
    cudnn_add_tensor as "cudnnAddTensor": PFN_cudnnAddTensor;
    cudnn_transform_tensor as "cudnnTransformTensor": PFN_cudnnTransformTensor;
    cudnn_scale_tensor as "cudnnScaleTensor": PFN_cudnnScaleTensor;
    cudnn_set_tensor as "cudnnSetTensor": PFN_cudnnSetTensor;
    // LRN descriptors and the cross-channel forward call.
    cudnn_create_lrn_descriptor as "cudnnCreateLRNDescriptor": PFN_cudnnCreateLRNDescriptor;
    cudnn_destroy_lrn_descriptor as "cudnnDestroyLRNDescriptor": PFN_cudnnDestroyLRNDescriptor;
    cudnn_set_lrn_descriptor as "cudnnSetLRNDescriptor": PFN_cudnnSetLRNDescriptor;
    cudnn_lrn_cross_channel_forward as "cudnnLRNCrossChannelForward": PFN_cudnnLRNCrossChannelForward;
    // Dropout.
    cudnn_create_dropout_descriptor as "cudnnCreateDropoutDescriptor": PFN_cudnnCreateDropoutDescriptor;
    cudnn_destroy_dropout_descriptor as "cudnnDestroyDropoutDescriptor": PFN_cudnnDestroyDropoutDescriptor;
    cudnn_dropout_get_states_size as "cudnnDropoutGetStatesSize": PFN_cudnnDropoutGetStatesSize;
    cudnn_dropout_get_reserve_space_size as "cudnnDropoutGetReserveSpaceSize": PFN_cudnnDropoutGetReserveSpaceSize;
    cudnn_set_dropout_descriptor as "cudnnSetDropoutDescriptor": PFN_cudnnSetDropoutDescriptor;
    cudnn_dropout_forward as "cudnnDropoutForward": PFN_cudnnDropoutForward;
    cudnn_dropout_backward as "cudnnDropoutBackward": PFN_cudnnDropoutBackward;
    // RNN descriptors, RNN data descriptors and the forward pass.
    cudnn_create_rnn_descriptor as "cudnnCreateRNNDescriptor": PFN_cudnnCreateRNNDescriptor;
    cudnn_destroy_rnn_descriptor as "cudnnDestroyRNNDescriptor": PFN_cudnnDestroyRNNDescriptor;
    cudnn_create_rnn_data_descriptor as "cudnnCreateRNNDataDescriptor": PFN_cudnnCreateRNNDataDescriptor;
    cudnn_destroy_rnn_data_descriptor as "cudnnDestroyRNNDataDescriptor": PFN_cudnnDestroyRNNDataDescriptor;
    cudnn_rnn_forward as "cudnnRNNForward": PFN_cudnnRNNForward;
    // Backend descriptor API.
    cudnn_backend_create_descriptor as "cudnnBackendCreateDescriptor": PFN_cudnnBackendCreateDescriptor;
    cudnn_backend_destroy_descriptor as "cudnnBackendDestroyDescriptor": PFN_cudnnBackendDestroyDescriptor;
    cudnn_backend_initialize as "cudnnBackendInitialize": PFN_cudnnBackendInitialize;
    cudnn_backend_finalize as "cudnnBackendFinalize": PFN_cudnnBackendFinalize;
    cudnn_backend_set_attribute as "cudnnBackendSetAttribute": PFN_cudnnBackendSetAttribute;
    cudnn_backend_get_attribute as "cudnnBackendGetAttribute": PFN_cudnnBackendGetAttribute;
    cudnn_backend_execute as "cudnnBackendExecute": PFN_cudnnBackendExecute;
    // N-dimensional descriptor variants.
    cudnn_set_tensor_nd_descriptor as "cudnnSetTensorNdDescriptor": PFN_cudnnSetTensorNdDescriptor;
    cudnn_get_tensor_nd_descriptor as "cudnnGetTensorNdDescriptor": PFN_cudnnGetTensorNdDescriptor;
    cudnn_set_filter_nd_descriptor as "cudnnSetFilterNdDescriptor": PFN_cudnnSetFilterNdDescriptor;
    cudnn_set_convolution_nd_descriptor as "cudnnSetConvolutionNdDescriptor": PFN_cudnnSetConvolutionNdDescriptor;
    cudnn_set_pooling_nd_descriptor as "cudnnSetPoolingNdDescriptor": PFN_cudnnSetPoolingNdDescriptor;
    // CTC loss.
    cudnn_create_ctc_loss_descriptor as "cudnnCreateCTCLossDescriptor": PFN_cudnnCreateCTCLossDescriptor;
    cudnn_destroy_ctc_loss_descriptor as "cudnnDestroyCTCLossDescriptor": PFN_cudnnDestroyCTCLossDescriptor;
    cudnn_set_ctc_loss_descriptor as "cudnnSetCTCLossDescriptor": PFN_cudnnSetCTCLossDescriptor;
    cudnn_get_ctc_loss_workspace_size as "cudnnGetCTCLossWorkspaceSize": PFN_cudnnGetCTCLossWorkspaceSize;
    cudnn_ctc_loss as "cudnnCTCLoss": PFN_cudnnCTCLoss;
    // RNN backward passes (`_v8` symbols).
    cudnn_rnn_backward_data_v8 as "cudnnRNNBackwardData_v8": PFN_cudnnRNNBackwardData_v8;
    cudnn_rnn_backward_weights_v8 as "cudnnRNNBackwardWeights_v8": PFN_cudnnRNNBackwardWeights_v8;
    // Spatial transformer.
    cudnn_create_spatial_transformer_descriptor as "cudnnCreateSpatialTransformerDescriptor": PFN_cudnnCreateSpatialTransformerDescriptor;
    cudnn_destroy_spatial_transformer_descriptor as "cudnnDestroySpatialTransformerDescriptor": PFN_cudnnDestroySpatialTransformerDescriptor;
    cudnn_set_spatial_transformer_nd_descriptor as "cudnnSetSpatialTransformerNdDescriptor": PFN_cudnnSetSpatialTransformerNdDescriptor;
    cudnn_spatial_tf_grid_generator_forward as "cudnnSpatialTfGridGeneratorForward": PFN_cudnnSpatialTfGridGeneratorForward;
    cudnn_spatial_tf_sampler_forward as "cudnnSpatialTfSamplerForward": PFN_cudnnSpatialTfSamplerForward;

    // Convolution descriptor extras (group count / math type / reorder type),
    // fused bias+activation forward, and additional backward passes and
    // descriptor getters.
    cudnn_set_convolution_group_count as "cudnnSetConvolutionGroupCount": PFN_cudnnSetConvolutionGroupCount;
    cudnn_get_convolution_group_count as "cudnnGetConvolutionGroupCount": PFN_cudnnGetConvolutionGroupCount;
    cudnn_set_convolution_math_type as "cudnnSetConvolutionMathType": PFN_cudnnSetConvolutionMathType;
    cudnn_get_convolution_math_type as "cudnnGetConvolutionMathType": PFN_cudnnGetConvolutionMathType;
    cudnn_set_convolution_reorder_type as "cudnnSetConvolutionReorderType": PFN_cudnnSetConvolutionReorderType;
    cudnn_get_convolution_reorder_type as "cudnnGetConvolutionReorderType": PFN_cudnnGetConvolutionReorderType;
    cudnn_reorder_filter_and_bias as "cudnnReorderFilterAndBias": PFN_cudnnReorderFilterAndBias;
    cudnn_convolution_bias_activation_forward as "cudnnConvolutionBiasActivationForward": PFN_cudnnConvolutionBiasActivationForward;
    cudnn_activation_backward as "cudnnActivationBackward": PFN_cudnnActivationBackward;
    cudnn_set_activation_descriptor_swish_beta as "cudnnSetActivationDescriptorSwishBeta": PFN_cudnnSetActivationDescriptorSwishBeta;
    cudnn_get_activation_descriptor_swish_beta as "cudnnGetActivationDescriptorSwishBeta": PFN_cudnnGetActivationDescriptorSwishBeta;
    cudnn_lrn_cross_channel_backward as "cudnnLRNCrossChannelBackward": PFN_cudnnLRNCrossChannelBackward;
    cudnn_divisive_normalization_forward as "cudnnDivisiveNormalizationForward": PFN_cudnnDivisiveNormalizationForward;
    cudnn_divisive_normalization_backward as "cudnnDivisiveNormalizationBackward": PFN_cudnnDivisiveNormalizationBackward;
    cudnn_get_reduction_indices_size as "cudnnGetReductionIndicesSize": PFN_cudnnGetReductionIndicesSize;
    cudnn_set_tensor_4d_descriptor_ex as "cudnnSetTensor4dDescriptorEx": PFN_cudnnSetTensor4dDescriptorEx;
    cudnn_get_tensor_4d_descriptor as "cudnnGetTensor4dDescriptor": PFN_cudnnGetTensor4dDescriptor;
    cudnn_get_filter_4d_descriptor as "cudnnGetFilter4dDescriptor": PFN_cudnnGetFilter4dDescriptor;
    cudnn_get_dropout_descriptor as "cudnnGetDropoutDescriptor": PFN_cudnnGetDropoutDescriptor;
    cudnn_restore_dropout_descriptor as "cudnnRestoreDropoutDescriptor": PFN_cudnnRestoreDropoutDescriptor;

    // Convolution algorithm selection (`Get…_v7` / `Find…` symbols).
    cudnn_get_convolution_forward_algorithm_v7 as "cudnnGetConvolutionForwardAlgorithm_v7": PFN_cudnnGetConvolutionForwardAlgorithm_v7;
    cudnn_find_convolution_forward_algorithm as "cudnnFindConvolutionForwardAlgorithm": PFN_cudnnFindConvolutionForwardAlgorithm;
    cudnn_find_convolution_forward_algorithm_ex as "cudnnFindConvolutionForwardAlgorithmEx": PFN_cudnnFindConvolutionForwardAlgorithmEx;
    cudnn_get_convolution_backward_data_algorithm_v7 as "cudnnGetConvolutionBackwardDataAlgorithm_v7": PFN_cudnnGetConvolutionBackwardDataAlgorithm_v7;
    cudnn_find_convolution_backward_data_algorithm as "cudnnFindConvolutionBackwardDataAlgorithm": PFN_cudnnFindConvolutionBackwardDataAlgorithm;
    cudnn_get_convolution_backward_filter_algorithm_v7 as "cudnnGetConvolutionBackwardFilterAlgorithm_v7": PFN_cudnnGetConvolutionBackwardFilterAlgorithm_v7;
    cudnn_find_convolution_backward_filter_algorithm as "cudnnFindConvolutionBackwardFilterAlgorithm": PFN_cudnnFindConvolutionBackwardFilterAlgorithm;

    // Batch-normalization `Ex` variants and the `cudnnNormalization*` family.
    cudnn_get_batch_normalization_forward_training_ex_workspace_size as "cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize": PFN_cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize;
    cudnn_get_batch_normalization_backward_ex_workspace_size as "cudnnGetBatchNormalizationBackwardExWorkspaceSize": PFN_cudnnGetBatchNormalizationBackwardExWorkspaceSize;
    cudnn_get_batch_normalization_training_ex_reserve_space_size as "cudnnGetBatchNormalizationTrainingExReserveSpaceSize": PFN_cudnnGetBatchNormalizationTrainingExReserveSpaceSize;
    cudnn_batch_normalization_forward_training_ex as "cudnnBatchNormalizationForwardTrainingEx": PFN_cudnnBatchNormalizationForwardTrainingEx;
    cudnn_batch_normalization_backward_ex as "cudnnBatchNormalizationBackwardEx": PFN_cudnnBatchNormalizationBackwardEx;
    cudnn_normalization_forward_inference as "cudnnNormalizationForwardInference": PFN_cudnnNormalizationForwardInference;
    cudnn_get_normalization_forward_training_workspace_size as "cudnnGetNormalizationForwardTrainingWorkspaceSize": PFN_cudnnGetNormalizationForwardTrainingWorkspaceSize;
    cudnn_get_normalization_backward_workspace_size as "cudnnGetNormalizationBackwardWorkspaceSize": PFN_cudnnGetNormalizationBackwardWorkspaceSize;
    cudnn_get_normalization_training_reserve_space_size as "cudnnGetNormalizationTrainingReserveSpaceSize": PFN_cudnnGetNormalizationTrainingReserveSpaceSize;
    cudnn_normalization_forward_training as "cudnnNormalizationForwardTraining": PFN_cudnnNormalizationForwardTraining;
    cudnn_normalization_backward as "cudnnNormalizationBackward": PFN_cudnnNormalizationBackward;

    // RNN `_v8` descriptor setup plus temp-space and weight-space queries.
    cudnn_set_rnn_descriptor_v8 as "cudnnSetRNNDescriptor_v8": PFN_cudnnSetRNNDescriptor_v8;
    cudnn_build_rnn_dynamic as "cudnnBuildRNNDynamic": PFN_cudnnBuildRNNDynamic;
    cudnn_get_rnn_temp_space_sizes as "cudnnGetRNNTempSpaceSizes": PFN_cudnnGetRNNTempSpaceSizes;
    cudnn_get_rnn_weight_space_size as "cudnnGetRNNWeightSpaceSize": PFN_cudnnGetRNNWeightSpaceSize;
    cudnn_get_rnn_weight_params as "cudnnGetRNNWeightParams": PFN_cudnnGetRNNWeightParams;

    // Multi-head attention and sequence-data descriptors.
    cudnn_create_attn_descriptor as "cudnnCreateAttnDescriptor": PFN_cudnnCreateAttnDescriptor;
    cudnn_destroy_attn_descriptor as "cudnnDestroyAttnDescriptor": PFN_cudnnDestroyAttnDescriptor;
    cudnn_set_attn_descriptor as "cudnnSetAttnDescriptor": PFN_cudnnSetAttnDescriptor;
    cudnn_get_multi_head_attn_buffers as "cudnnGetMultiHeadAttnBuffers": PFN_cudnnGetMultiHeadAttnBuffers;
    cudnn_get_multi_head_attn_weights as "cudnnGetMultiHeadAttnWeights": PFN_cudnnGetMultiHeadAttnWeights;
    cudnn_multi_head_attn_forward as "cudnnMultiHeadAttnForward": PFN_cudnnMultiHeadAttnForward;
    cudnn_multi_head_attn_backward_data as "cudnnMultiHeadAttnBackwardData": PFN_cudnnMultiHeadAttnBackwardData;
    cudnn_multi_head_attn_backward_weights as "cudnnMultiHeadAttnBackwardWeights": PFN_cudnnMultiHeadAttnBackwardWeights;
    cudnn_create_seq_data_descriptor as "cudnnCreateSeqDataDescriptor": PFN_cudnnCreateSeqDataDescriptor;
    cudnn_destroy_seq_data_descriptor as "cudnnDestroySeqDataDescriptor": PFN_cudnnDestroySeqDataDescriptor;
    cudnn_set_seq_data_descriptor as "cudnnSetSeqDataDescriptor": PFN_cudnnSetSeqDataDescriptor;
}
2771
2772pub fn cudnn() -> Result<&'static Cudnn, LoaderError> {
2774 static CUDNN: OnceLock<Cudnn> = OnceLock::new();
2775 if let Some(c) = CUDNN.get() {
2776 return Ok(c);
2777 }
2778 let lib = open_cudnn()?;
2779 let c = Cudnn::empty(lib);
2780 let _ = CUDNN.set(c);
2781 Ok(CUDNN.get().expect("OnceLock set or lost race"))
2782}
2783
#[cfg(test)]
mod search_dir_tests {
    use super::*;

    /// A `v<major>.<minor>` component yields the major version, whether it is
    /// the last path component or followed by `bin`, and with either separator.
    #[test]
    fn parse_cuda_major_handles_typical_windows_paths() {
        let cases = [
            (
                r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6",
                12,
            ),
            (
                r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin",
                11,
            ),
            ("/opt/cuda/v13.0", 13),
        ];
        for (path, major) in cases {
            assert_eq!(parse_cuda_major_from_path(path), Some(major));
        }
    }

    /// Paths without a version component (including one containing a word that
    /// merely starts with `v`, and the empty string) produce `None`.
    #[test]
    fn parse_cuda_major_ignores_unrelated_v_prefixed_words() {
        for path in ["/usr/local/verbose/cuda", "", "/usr/local/cuda"] {
            assert_eq!(parse_cuda_major_from_path(path), None);
        }
    }

    /// An upper-case `V` prefix is accepted as well.
    #[test]
    fn parse_cuda_major_accepts_uppercase_v() {
        let major = parse_cuda_major_from_path(r"D:\NVIDIA\CUDA\V12.6\bin");
        assert_eq!(major, Some(12));
    }

    // The `dir_selection_*` tests below run the selection logic inline on
    // literal data; they do not call into the loader.
    // NOTE(review): presumably they mirror the policy of the search-dir
    // discovery code elsewhere in this file — confirm the two stay in sync.

    /// With a known target major, only the directory with that exact major is
    /// selected.
    #[test]
    fn dir_selection_prefers_target_major() {
        let available: Vec<(u32, &str)> = vec![(11, "/cudnn/bin/11"), (12, "/cudnn/bin/12")];
        let wanted: Option<u32> = Some(12);

        let picked: Vec<&str> = if let Some(major) = wanted {
            available
                .iter()
                .filter(|(n, _)| *n == major)
                .map(|(_, dir)| *dir)
                .take(1)
                .collect()
        } else {
            available.iter().map(|(_, dir)| *dir).collect()
        };

        assert_eq!(picked, vec!["/cudnn/bin/12"]);
    }

    /// Without an exact match, the highest major not exceeding the target wins.
    #[test]
    fn dir_selection_falls_back_to_highest_le_target() {
        let mut available: Vec<(u32, &str)> = vec![(11, "/cudnn/11"), (12, "/cudnn/12")];
        let wanted = 13u32;

        if available.iter().any(|(n, _)| *n == wanted) {
            unreachable!("no exact match in this scenario");
        }

        // Descending by major, so the first entry at or below the target is the
        // closest one.
        available.sort_by_key(|entry| std::cmp::Reverse(entry.0));
        let fallback = available
            .iter()
            .find(|(n, _)| *n <= wanted)
            .map(|(_, dir)| *dir);

        assert_eq!(fallback, Some("/cudnn/12"));
    }

    /// With no target at all, directories are ordered from highest major down.
    #[test]
    fn dir_selection_no_signal_is_highest_first() {
        let mut available: Vec<(u32, &str)> = vec![(11, "/11"), (13, "/13"), (12, "/12")];
        available.sort_by(|a, b| b.0.cmp(&a.0));

        let order: Vec<&str> = available.into_iter().map(|(_, dir)| dir).collect();
        assert_eq!(order, vec!["/13", "/12", "/11"]);
    }
}