// Source file: baracuda_cudnn_sys/lib.rs

//! Raw FFI + dynamic loader for NVIDIA cuDNN (classic-API subset).
//!
//! `baracuda-cudnn` wraps this with a safe, typed API. Use this crate
//! directly only if you need a function that the safe layer hasn't
//! wrapped yet (in which case please file a bug).
//!
//! Handles the non-standard cuDNN install location on Windows
//! (`C:\Program Files\NVIDIA\CUDNN\v<ver>\bin\<cuda-major>`) by probing
//! it in addition to the usual `baracuda_core::platform` search paths.
10
// FFI item names intentionally keep the spelling of the cuDNN C typedefs they mirror.
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals)]
#![warn(missing_debug_implementations)]

use core::ffi::{c_double, c_int, c_void};
use std::path::PathBuf;
use std::sync::OnceLock;

use baracuda_core::{Library, LoaderError};
use baracuda_cuda_sys::runtime::cudaStream_t;
use baracuda_types::CudaStatus;
21
// Every alias in this group is a plain `*mut c_void`. Because they are type
// aliases (not newtypes) the compiler treats them as interchangeable, so
// passing the wrong descriptor kind is not caught at compile time.

/// Opaque pointer mirroring the C typedef `cudnnHandle_t`.
pub type cudnnHandle_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnTensorDescriptor_t`.
pub type cudnnTensorDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnActivationDescriptor_t`.
pub type cudnnActivationDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnFilterDescriptor_t`.
pub type cudnnFilterDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnConvolutionDescriptor_t`.
pub type cudnnConvolutionDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnPoolingDescriptor_t`.
pub type cudnnPoolingDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnLRNDescriptor_t`.
pub type cudnnLRNDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnOpTensorDescriptor_t`.
pub type cudnnOpTensorDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnReduceTensorDescriptor_t`.
pub type cudnnReduceTensorDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnDropoutDescriptor_t`.
pub type cudnnDropoutDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnCTCLossDescriptor_t`.
pub type cudnnCTCLossDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnRNNDescriptor_t`.
pub type cudnnRNNDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnRNNDataDescriptor_t`.
pub type cudnnRNNDataDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnBackendDescriptor_t`.
pub type cudnnBackendDescriptor_t = *mut c_void;
50
/// Forward-convolution algorithm selector. Mirrors `cudnnConvolutionFwdAlgo_t`.
///
/// `#[repr(i32)]` with explicit discriminants: each variant is passed across
/// the C ABI as the integer written next to it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnConvolutionFwdAlgo_t {
    /// `CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM`.
    ImplicitGemm = 0,
    /// `CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM`.
    ImplicitPrecompGemm = 1,
    /// `CUDNN_CONVOLUTION_FWD_ALGO_GEMM`.
    Gemm = 2,
    /// `CUDNN_CONVOLUTION_FWD_ALGO_DIRECT`.
    Direct = 3,
    /// `CUDNN_CONVOLUTION_FWD_ALGO_FFT`.
    Fft = 4,
    /// `CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING`.
    FftTiling = 5,
    /// `CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD`.
    Winograd = 6,
    /// `CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED`.
    WinogradNonfused = 7,
}
72
/// Selects true convolution versus cross-correlation. Mirrors
/// `cudnnConvolutionMode_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnConvolutionMode_t {
    /// `CUDNN_CONVOLUTION`.
    Convolution = 0,
    /// `CUDNN_CROSS_CORRELATION`.
    CrossCorrelation = 1,
}
82
/// Element data type of a tensor or filter. Mirrors `cudnnDataType_t`.
///
/// Only the values listed here are representable; this is a subset of what
/// newer cuDNN headers define.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnDataType_t {
    /// `CUDNN_DATA_FLOAT`.
    Float = 0,
    /// `CUDNN_DATA_DOUBLE`.
    Double = 1,
    /// `CUDNN_DATA_HALF`.
    Half = 2,
    /// `CUDNN_DATA_INT8`.
    Int8 = 3,
    /// `CUDNN_DATA_INT32`.
    Int32 = 4,
    /// `CUDNN_DATA_INT8x4`.
    Int8x4 = 5,
    /// `CUDNN_DATA_UINT8`.
    Uint8 = 6,
    /// `CUDNN_DATA_UINT8x4`.
    Uint8x4 = 7,
    /// `CUDNN_DATA_INT8x32`.
    Int8x32 = 8,
    /// `CUDNN_DATA_BFLOAT16`.
    BFloat16 = 9,
}
108
/// Memory layout selector for 4-D tensors and filters. Mirrors
/// `cudnnTensorFormat_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnTensorFormat_t {
    /// `CUDNN_TENSOR_NCHW`.
    Nchw = 0,
    /// `CUDNN_TENSOR_NHWC`.
    Nhwc = 1,
    /// `CUDNN_TENSOR_NCHW_VECT_C`.
    NchwVectC = 2,
}
120
/// Activation function selector. Mirrors `cudnnActivationMode_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnActivationMode_t {
    /// `CUDNN_ACTIVATION_SIGMOID`.
    Sigmoid = 0,
    /// `CUDNN_ACTIVATION_RELU`.
    Relu = 1,
    /// `CUDNN_ACTIVATION_TANH`.
    Tanh = 2,
    /// `CUDNN_ACTIVATION_CLIPPED_RELU`.
    ClippedRelu = 3,
    /// `CUDNN_ACTIVATION_ELU`.
    Elu = 4,
    /// `CUDNN_ACTIVATION_IDENTITY`.
    Identity = 5,
    /// `CUDNN_ACTIVATION_SWISH`.
    Swish = 6,
}
140
/// NaN-propagation switch. Mirrors `cudnnNanPropagation_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnNanPropagation_t {
    /// `CUDNN_NOT_PROPAGATE_NAN`.
    NotPropagateNan = 0,
    /// `CUDNN_PROPAGATE_NAN`.
    PropagateNan = 1,
}
150
/// Pooling mode selector. Mirrors `cudnnPoolingMode_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnPoolingMode_t {
    /// `CUDNN_POOLING_MAX`.
    Max = 0,
    /// `CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING`.
    AverageCountIncludePadding = 1,
    /// `CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING`.
    AverageCountExcludePadding = 2,
    /// `CUDNN_POOLING_MAX_DETERMINISTIC`.
    MaxDeterministic = 3,
}
164
/// Softmax algorithm selector. Mirrors `cudnnSoftmaxAlgorithm_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnSoftmaxAlgorithm_t {
    /// `CUDNN_SOFTMAX_FAST`.
    Fast = 0,
    /// `CUDNN_SOFTMAX_ACCURATE`.
    Accurate = 1,
    /// `CUDNN_SOFTMAX_LOG`.
    Log = 2,
}
176
/// Softmax reduction-scope selector. Mirrors `cudnnSoftmaxMode_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnSoftmaxMode_t {
    /// `CUDNN_SOFTMAX_MODE_INSTANCE`.
    Instance = 0,
    /// `CUDNN_SOFTMAX_MODE_CHANNEL`.
    Channel = 1,
}
186
/// Batch-normalization mode selector. Mirrors `cudnnBatchNormMode_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnBatchNormMode_t {
    /// `CUDNN_BATCHNORM_PER_ACTIVATION`.
    PerActivation = 0,
    /// `CUDNN_BATCHNORM_SPATIAL`.
    Spatial = 1,
    /// `CUDNN_BATCHNORM_SPATIAL_PERSISTENT`.
    SpatialPersistent = 2,
}
198
/// Element-wise tensor operation selector. Mirrors `cudnnOpTensorOp_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnOpTensorOp_t {
    /// `CUDNN_OP_TENSOR_ADD`.
    Add = 0,
    /// `CUDNN_OP_TENSOR_MUL`.
    Mul = 1,
    /// `CUDNN_OP_TENSOR_MIN`.
    Min = 2,
    /// `CUDNN_OP_TENSOR_MAX`.
    Max = 3,
    /// `CUDNN_OP_TENSOR_SQRT`.
    Sqrt = 4,
    /// `CUDNN_OP_TENSOR_NOT`.
    Not = 5,
}
216
/// Tensor reduction operation selector. Mirrors `cudnnReduceTensorOp_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnReduceTensorOp_t {
    /// `CUDNN_REDUCE_TENSOR_ADD`.
    Add = 0,
    /// `CUDNN_REDUCE_TENSOR_MUL`.
    Mul = 1,
    /// `CUDNN_REDUCE_TENSOR_MIN`.
    Min = 2,
    /// `CUDNN_REDUCE_TENSOR_MAX`.
    Max = 3,
    /// `CUDNN_REDUCE_TENSOR_AMAX`.
    Amax = 4,
    /// `CUDNN_REDUCE_TENSOR_AVG`.
    Avg = 5,
    /// `CUDNN_REDUCE_TENSOR_NORM1`.
    Norm1 = 6,
    /// `CUDNN_REDUCE_TENSOR_NORM2`.
    Norm2 = 7,
    /// `CUDNN_REDUCE_TENSOR_MUL_NO_ZEROS`.
    MulNoZeros = 8,
}
240
/// Whether a reduction also produces indices. Mirrors
/// `cudnnReduceTensorIndices_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnReduceTensorIndices_t {
    /// `CUDNN_REDUCE_TENSOR_NO_INDICES`.
    NoIndices = 0,
    /// `CUDNN_REDUCE_TENSOR_FLATTENED_INDICES`.
    FlattenedIndices = 1,
}
250
/// Integer width used for reduction indices. Mirrors `cudnnIndicesType_t`.
///
/// Note the discriminants are not ordered by width (32, 64, 16, 8).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnIndicesType_t {
    /// `CUDNN_32BIT_INDICES`.
    U32 = 0,
    /// `CUDNN_64BIT_INDICES`.
    U64 = 1,
    /// `CUDNN_16BIT_INDICES`.
    U16 = 2,
    /// `CUDNN_8BIT_INDICES`.
    U8 = 3,
}
264
/// Recurrent cell type selector. Mirrors `cudnnRNNMode_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnRNNMode_t {
    /// `CUDNN_RNN_RELU`.
    ReluRnn = 0,
    /// `CUDNN_RNN_TANH`.
    TanhRnn = 1,
    /// `CUDNN_LSTM`.
    Lstm = 2,
    /// `CUDNN_GRU`.
    Gru = 3,
}
278
/// RNN direction selector. Mirrors `cudnnDirectionMode_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnDirectionMode_t {
    /// `CUDNN_UNIDIRECTIONAL`.
    Unidirectional = 0,
    /// `CUDNN_BIDIRECTIONAL`.
    Bidirectional = 1,
}
288
/// RNN first-layer input handling selector. Mirrors `cudnnRNNInputMode_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnRNNInputMode_t {
    /// `CUDNN_LINEAR_INPUT`.
    LinearInput = 0,
    /// `CUDNN_SKIP_INPUT`.
    SkipInput = 1,
}
298
/// RNN execution algorithm selector. Mirrors `cudnnRNNAlgo_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnRNNAlgo_t {
    /// `CUDNN_RNN_ALGO_STANDARD`.
    Standard = 0,
    /// `CUDNN_RNN_ALGO_PERSIST_STATIC`.
    PersistStatic = 1,
    /// `CUDNN_RNN_ALGO_PERSIST_DYNAMIC`.
    PersistDynamic = 2,
    /// `CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H`.
    PersistStaticSmallH = 3,
}
312
/// Backward-data convolution algorithm selector. Mirrors
/// `cudnnConvolutionBwdDataAlgo_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnConvolutionBwdDataAlgo_t {
    /// `CUDNN_CONVOLUTION_BWD_DATA_ALGO_0`.
    Algo0 = 0,
    /// `CUDNN_CONVOLUTION_BWD_DATA_ALGO_1`.
    Algo1 = 1,
    /// `CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT`.
    Fft = 2,
    /// `CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING`.
    FftTiling = 3,
    /// `CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD`.
    Winograd = 4,
    /// `CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED`.
    WinogradNonfused = 5,
}
330
/// Backward-filter convolution algorithm selector. Mirrors
/// `cudnnConvolutionBwdFilterAlgo_t`.
///
/// The numbering differs from the backward-data enum: `Algo3` sits at 3 and
/// `FftTiling` comes last at 6.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnConvolutionBwdFilterAlgo_t {
    /// `CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0`.
    Algo0 = 0,
    /// `CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1`.
    Algo1 = 1,
    /// `CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT`.
    Fft = 2,
    /// `CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3`.
    Algo3 = 3,
    /// `CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD`.
    Winograd = 4,
    /// `CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED`.
    WinogradNonfused = 5,
    /// `CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING`.
    FftTiling = 6,
}
350
/// Kind tag for a backend (graph API) descriptor. Mirrors
/// `cudnnBackendDescriptorType_t`.
///
/// NOTE(review): the discriminants below are reproduced unchanged. Value 29
/// is skipped, and the numbering may not line up with the enum in
/// `cudnn_backend.h` — confirm every value against the header of the cuDNN
/// version this crate targets before relying on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnBackendDescriptorType_t {
    /// Descriptor kind: pointwise.
    PointwiseDescriptor = 0,
    /// Descriptor kind: convolution.
    ConvolutionDescriptor = 1,
    /// Descriptor kind: engine.
    EngineDescriptor = 2,
    /// Descriptor kind: engine configuration.
    EngineCfgDescriptor = 3,
    /// Descriptor kind: execution plan.
    ExecutionPlanDescriptor = 4,
    /// Descriptor kind: intermediate info.
    IntermediateInfoDescriptor = 5,
    /// Descriptor kind: knob choice.
    KnobChoiceDescriptor = 6,
    /// Descriptor kind: knob info.
    KnobInfoDescriptor = 7,
    /// Descriptor kind: layout info.
    LayoutInfoDescriptor = 8,
    /// Descriptor kind: convolution-forward operation.
    OperationConvolutionForwardDescriptor = 9,
    /// Descriptor kind: convolution-backward-filter operation.
    OperationConvolutionBackwardFilterDescriptor = 10,
    /// Descriptor kind: convolution-backward-data operation.
    OperationConvolutionBackwardDataDescriptor = 11,
    /// Descriptor kind: pointwise operation.
    OperationPointwiseDescriptor = 12,
    /// Descriptor kind: gen-stats operation.
    OperationGenStatsDescriptor = 13,
    /// Descriptor kind: reduction operation.
    OperationReductionDescriptor = 14,
    /// Descriptor kind: batch-norm finalize-statistics operation.
    OperationBnFinalizeStatisticsDescriptor = 15,
    /// Descriptor kind: operation graph.
    OperationGraphDescriptor = 16,
    /// Descriptor kind: variant pack.
    VariantPackDescriptor = 17,
    /// Descriptor kind: tensor.
    TensorDescriptor = 18,
    /// Descriptor kind: matmul.
    MatmulDescriptor = 19,
    /// Descriptor kind: matmul operation.
    OperationMatmulDescriptor = 20,
    /// Descriptor kind: batch-norm backward-weights operation.
    OperationBnBwdWeightsDescriptor = 21,
    /// Descriptor kind: resample.
    ResampleDescriptor = 22,
    /// Descriptor kind: resample-forward operation.
    OperationResampleFwdDescriptor = 23,
    /// Descriptor kind: resample-backward operation.
    OperationResampleBwdDescriptor = 24,
    /// Descriptor kind: concat operation.
    OperationConcatDescriptor = 25,
    /// Descriptor kind: signal operation.
    OperationSignalDescriptor = 26,
    /// Descriptor kind: norm-forward operation.
    OperationNormForwardDescriptor = 27,
    /// Descriptor kind: norm-backward operation.
    OperationNormBackwardDescriptor = 28,
    // 29 is intentionally absent here, exactly as in the original listing.
    /// Descriptor kind: RNG operation.
    OperationRngDescriptor = 30,
    /// Descriptor kind: RNG.
    RngDescriptor = 31,
}
418
/// Attribute identifier for backend descriptors. Mirrors
/// `cudnnBackendAttributeName_t`.
///
/// NOTE(review): discriminants are reproduced unchanged; confirm the numeric
/// ranges (0.., 100.., 200.., 500.., 600..) against `cudnn_backend.h` for
/// the targeted cuDNN version before passing them to the library.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnBackendAttributeName_t {
    // Just a representative subset — the real enum has ~200 entries.

    // Pointwise descriptor
    /// Attribute: pointwise mode.
    PointwiseMode = 0,
    /// Attribute: pointwise math precision.
    PointwiseMathPrec = 1,
    /// Attribute: pointwise NaN propagation.
    PointwiseNanPropagation = 2,
    /// Attribute: pointwise ReLU lower clip.
    PointwiseReluLowerClip = 3,
    /// Attribute: pointwise ReLU upper clip.
    PointwiseReluUpperClip = 4,
    /// Attribute: pointwise ELU alpha.
    PointwiseEluAlpha = 5,

    // Tensor descriptor
    /// Attribute: tensor unique id.
    TensorUniqueId = 100,
    /// Attribute: tensor data type.
    TensorDataType = 101,
    /// Attribute: tensor byte alignment.
    TensorByteAlignment = 102,
    /// Attribute: tensor dimensions.
    TensorDimensions = 103,
    /// Attribute: tensor strides.
    TensorStrides = 104,

    // Convolution descriptor
    /// Attribute: convolution compute type.
    ConvolutionCompType = 200,
    /// Attribute: convolution mode.
    ConvolutionConvMode = 201,
    /// Attribute: convolution dilations.
    ConvolutionDilations = 202,
    /// Attribute: convolution filter strides.
    ConvolutionFilterStrides = 203,
    /// Attribute: convolution pre-paddings.
    ConvolutionPrePaddings = 204,
    /// Attribute: convolution post-paddings.
    ConvolutionPostPaddings = 205,
    /// Attribute: convolution spatial dimension count.
    ConvolutionSpatialDims = 206,

    // Operation graph
    /// Attribute: operation-graph handle.
    OperationGraphHandle = 500,
    /// Attribute: operation-graph ops.
    OperationGraphOps = 501,

    // Execution plan
    /// Attribute: execution-plan handle.
    ExecutionPlanHandle = 600,
    /// Attribute: execution-plan engine config.
    ExecutionPlanEngineConfig = 601,
    /// Attribute: execution-plan workspace size.
    ExecutionPlanWorkspaceSize = 602,
}
475
/// Type tag describing the payload of a backend attribute. Mirrors
/// `cudnnBackendAttributeType_t`.
///
/// NOTE(review): discriminants are reproduced unchanged. The upstream header
/// appears to place `CUDNN_TYPE_VOID_PTR` at 6 and `CUDNN_TYPE_POINTWISE_MODE`
/// at 14 (the reverse of `PointwiseMode = 6` / `PointerT = 14` here), and the
/// tail of the list (23 and up) may also differ — confirm against
/// `cudnn_backend.h` for the targeted cuDNN version.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnBackendAttributeType_t {
    /// Type tag: handle.
    Handle = 0,
    /// Type tag: data type.
    DataType = 1,
    /// Type tag: boolean.
    Boolean = 2,
    /// Type tag: 64-bit integer.
    Int64 = 3,
    /// Type tag: float value.
    FloatValue = 4,
    /// Type tag: double value.
    DoubleValue = 5,
    /// Type tag: pointwise mode.
    PointwiseMode = 6,
    /// Type tag: convolution mode.
    ConvolutionMode = 7,
    /// Type tag: heuristic mode.
    HeurMode = 8,
    /// Type tag: knob type.
    KnobType = 9,
    /// Type tag: NaN propagation.
    NanPropagation = 10,
    /// Type tag: numerical note.
    NumericalNote = 11,
    /// Type tag: layout type.
    LayoutType = 12,
    /// Type tag: attribute name.
    AttribName = 13,
    /// Type tag: pointer.
    PointerT = 14,
    /// Type tag: backend descriptor.
    BackendDescriptor = 15,
    /// Type tag: gen-stats mode.
    GenstatsMode = 16,
    /// Type tag: batch-norm finalize-stats mode.
    BnFinalizeStatsMode = 17,
    /// Type tag: reduction operator type.
    ReductionOperatorType = 18,
    /// Type tag: behavior note.
    BehaviorNote = 19,
    /// Type tag: tensor reordering mode.
    TensorReorderingMode = 20,
    /// Type tag: resample mode.
    ResampleMode = 21,
    /// Type tag: padding mode.
    PaddingMode = 22,
    /// Type tag: integer array.
    IntArray = 23,
    /// Type tag: RNG distribution.
    RngDistribution = 24,
    /// Type tag: norm mode.
    NormMode = 25,
    /// Type tag: norm forward phase.
    NormFwdPhase = 26,
    /// Type tag: RNG normal.
    RngNormal = 27,
    /// Type tag: RNG uniform.
    RngUniform = 28,
}
539
// ---- new enums for v7 algorithm selection / convolution math / norm ------
541
/// Math type for a convolution descriptor — controls tensor-core usage.
/// Mirrors `cudnnMathType_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnMathType_t {
    /// `CUDNN_DEFAULT_MATH`: default math.
    DefaultMath = 0,
    /// `CUDNN_TENSOR_OP_MATH`: allow tensor-core math (Volta+).
    TensorOpMath = 1,
    /// `CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION`: allow tensor-core math with
    /// implicit f16/bf16 down-conversion.
    TensorOpMathAllowConversion = 2,
    /// `CUDNN_FMA_MATH`: strict FMA-only math.
    FmaMath = 3,
}
555
/// Filter / bias reorder selector for INT8 quantized inference. Mirrors
/// `cudnnReorderType_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnReorderType_t {
    /// `CUDNN_DEFAULT_REORDER`.
    DefaultReorder = 0,
    /// `CUDNN_NO_REORDER`.
    NoReorder = 1,
}
565
/// Generic-normalization mode (cuDNN 8+). Mirrors `cudnnNormMode_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnNormMode_t {
    /// `CUDNN_NORM_PER_ACTIVATION`.
    PerActivation = 0,
    /// `CUDNN_NORM_PER_CHANNEL`.
    PerChannel = 1,
}
575
/// Generic-normalization algorithm selector. Mirrors `cudnnNormAlgo_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnNormAlgo_t {
    /// `CUDNN_NORM_ALGO_STANDARD`.
    Standard = 0,
    /// `CUDNN_NORM_ALGO_PERSIST`.
    Persist = 1,
}
585
/// Optional op fused into a normalization call (none / activation /
/// add + activation). Mirrors `cudnnNormOps_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnNormOps_t {
    /// `CUDNN_NORM_OPS_NORM`: normalization only.
    Norm = 0,
    /// `CUDNN_NORM_OPS_NORM_ACTIVATION`: normalization followed by activation.
    NormActivation = 1,
    /// `CUDNN_NORM_OPS_NORM_ADD_ACTIVATION`: normalization, add, then activation.
    NormAddActivation = 2,
}
597
/// Optional op fused into the batch-normalization `Ex` calls. Mirrors
/// `cudnnBatchNormOps_t`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnBatchNormOps_t {
    /// `CUDNN_BATCHNORM_OPS_BN`: batch normalization only.
    Bn = 0,
    /// `CUDNN_BATCHNORM_OPS_BN_ACTIVATION`: batch normalization followed by activation.
    BnActivation = 1,
    /// `CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION`: batch normalization, add, then activation.
    BnAddActivation = 2,
}
609
/// Mirrors `cudnnDeterminism_t`: marks an algorithm choice reported in the
/// `*AlgoPerf_t` rows as deterministic or non-deterministic.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum cudnnDeterminism_t {
    /// `CUDNN_NON_DETERMINISTIC`.
    NonDeterministic = 0,
    /// `CUDNN_DETERMINISTIC`.
    Deterministic = 1,
}
620
621/// Result row from `cudnnFindConvolutionForwardAlgorithm` /
622/// `cudnnGetConvolutionForwardAlgorithm_v7`.
623#[repr(C)]
624#[derive(Copy, Clone, Debug)]
625pub struct cudnnConvolutionFwdAlgoPerf_t {
626    /// Algo field.
627    pub algo: cudnnConvolutionFwdAlgo_t,
628    /// Status field.
629    pub status: cudnnStatus_t,
630    /// Time field.
631    pub time: f32,
632    /// Memory field.
633    pub memory: usize,
634    /// Determinism field.
635    pub determinism: cudnnDeterminism_t,
636    /// Math type field.
637    pub math_type: cudnnMathType_t,
638    /// Reserved padding; do not use.
639    pub reserved: [c_int; 3],
640}
641
642/// Algorithm-finder performance row. Mirrors `cudnnConvolutionBwdDataAlgoPerf_t`.
643#[repr(C)]
644#[derive(Copy, Clone, Debug)]
645pub struct cudnnConvolutionBwdDataAlgoPerf_t {
646    /// Algo field.
647    pub algo: cudnnConvolutionBwdDataAlgo_t,
648    /// Status field.
649    pub status: cudnnStatus_t,
650    /// Time field.
651    pub time: f32,
652    /// Memory field.
653    pub memory: usize,
654    /// Determinism field.
655    pub determinism: cudnnDeterminism_t,
656    /// Math type field.
657    pub math_type: cudnnMathType_t,
658    /// Reserved padding; do not use.
659    pub reserved: [c_int; 3],
660}
661
662/// Algorithm-finder performance row. Mirrors `cudnnConvolutionBwdFilterAlgoPerf_t`.
663#[repr(C)]
664#[derive(Copy, Clone, Debug)]
665pub struct cudnnConvolutionBwdFilterAlgoPerf_t {
666    /// Algo field.
667    pub algo: cudnnConvolutionBwdFilterAlgo_t,
668    /// Status field.
669    pub status: cudnnStatus_t,
670    /// Time field.
671    pub time: f32,
672    /// Memory field.
673    pub memory: usize,
674    /// Determinism field.
675    pub determinism: cudnnDeterminism_t,
676    /// Math type field.
677    pub math_type: cudnnMathType_t,
678    /// Reserved padding; do not use.
679    pub reserved: [c_int; 3],
680}
681
// ---- new opaque descriptors --------------------------------------------------
683
/// Opaque pointer mirroring the C typedef `cudnnTensorTransformDescriptor_t`.
pub type cudnnTensorTransformDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnAttnDescriptor_t`.
pub type cudnnAttnDescriptor_t = *mut c_void;
/// Opaque pointer mirroring the C typedef `cudnnSeqDataDescriptor_t`.
pub type cudnnSeqDataDescriptor_t = *mut c_void;
690
// ---- status ---------------------------------------------------------------
692
/// Status code returned by cuDNN entry points. Mirrors `cudnnStatus_t`.
///
/// Modeled as a transparent newtype over `i32` rather than a Rust enum, so a
/// code this crate has no constant for is still a valid value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct cudnnStatus_t(pub i32);

impl cudnnStatus_t {
    /// `CUDNN_STATUS_SUCCESS`.
    pub const SUCCESS: Self = Self(0);
    /// `CUDNN_STATUS_NOT_INITIALIZED`.
    pub const NOT_INITIALIZED: Self = Self(1);
    /// `CUDNN_STATUS_ALLOC_FAILED`.
    pub const ALLOC_FAILED: Self = Self(2);
    /// `CUDNN_STATUS_BAD_PARAM`.
    pub const BAD_PARAM: Self = Self(3);
    /// `CUDNN_STATUS_INTERNAL_ERROR`.
    pub const INTERNAL_ERROR: Self = Self(4);
    /// `CUDNN_STATUS_INVALID_VALUE`.
    pub const INVALID_VALUE: Self = Self(5);
    /// `CUDNN_STATUS_ARCH_MISMATCH`.
    pub const ARCH_MISMATCH: Self = Self(6);
    /// `CUDNN_STATUS_MAPPING_ERROR`.
    pub const MAPPING_ERROR: Self = Self(7);
    /// `CUDNN_STATUS_EXECUTION_FAILED`.
    pub const EXECUTION_FAILED: Self = Self(8);
    /// `CUDNN_STATUS_NOT_SUPPORTED`.
    pub const NOT_SUPPORTED: Self = Self(9);
    /// `CUDNN_STATUS_LICENSE_ERROR`.
    pub const LICENSE_ERROR: Self = Self(10);

    /// Returns `true` only when the code equals [`Self::SUCCESS`].
    pub const fn is_success(self) -> bool {
        self.0 == Self::SUCCESS.0
    }
}
727
728impl CudaStatus for cudnnStatus_t {
729    fn code(self) -> i32 {
730        self.0
731    }
732    fn name(self) -> &'static str {
733        match self.0 {
734            0 => "CUDNN_STATUS_SUCCESS",
735            1 => "CUDNN_STATUS_NOT_INITIALIZED",
736            2 => "CUDNN_STATUS_ALLOC_FAILED",
737            3 => "CUDNN_STATUS_BAD_PARAM",
738            4 => "CUDNN_STATUS_INTERNAL_ERROR",
739            5 => "CUDNN_STATUS_INVALID_VALUE",
740            6 => "CUDNN_STATUS_ARCH_MISMATCH",
741            8 => "CUDNN_STATUS_EXECUTION_FAILED",
742            9 => "CUDNN_STATUS_NOT_SUPPORTED",
743            _ => "CUDNN_STATUS_UNRECOGNIZED",
744        }
745    }
746    fn description(self) -> &'static str {
747        match self.0 {
748            0 => "success",
749            1 => "cuDNN not initialized",
750            3 => "bad parameter",
751            9 => "operation not supported on this device/version",
752            _ => "unrecognized cuDNN status code",
753        }
754    }
755    fn is_success(self) -> bool {
756        cudnnStatus_t::is_success(self)
757    }
758    fn library(self) -> &'static str {
759        "cudnn"
760    }
761}
762
// ---- function-pointer types ----------------------------------------------
764
765/// cuDNN: create. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
766pub type PFN_cudnnCreate = unsafe extern "C" fn(handle: *mut cudnnHandle_t) -> cudnnStatus_t;
767/// cuDNN: destroy. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
768pub type PFN_cudnnDestroy = unsafe extern "C" fn(handle: cudnnHandle_t) -> cudnnStatus_t;
769/// cuDNN: set stream. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
770pub type PFN_cudnnSetStream =
771    unsafe extern "C" fn(handle: cudnnHandle_t, stream: cudaStream_t) -> cudnnStatus_t;
772/// cuDNN: get version. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
773pub type PFN_cudnnGetVersion = unsafe extern "C" fn() -> usize;
774
775/// cuDNN: create tensor descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
776pub type PFN_cudnnCreateTensorDescriptor =
777    unsafe extern "C" fn(desc: *mut cudnnTensorDescriptor_t) -> cudnnStatus_t;
778/// cuDNN: destroy tensor descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
779pub type PFN_cudnnDestroyTensorDescriptor =
780    unsafe extern "C" fn(desc: cudnnTensorDescriptor_t) -> cudnnStatus_t;
781/// cuDNN: set tensor4d descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
782pub type PFN_cudnnSetTensor4dDescriptor = unsafe extern "C" fn(
783    desc: cudnnTensorDescriptor_t,
784    format: cudnnTensorFormat_t,
785    data_type: cudnnDataType_t,
786    n: c_int,
787    c: c_int,
788    h: c_int,
789    w: c_int,
790) -> cudnnStatus_t;
791
792/// cuDNN: create activation descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
793pub type PFN_cudnnCreateActivationDescriptor =
794    unsafe extern "C" fn(desc: *mut cudnnActivationDescriptor_t) -> cudnnStatus_t;
795/// cuDNN: destroy activation descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
796pub type PFN_cudnnDestroyActivationDescriptor =
797    unsafe extern "C" fn(desc: cudnnActivationDescriptor_t) -> cudnnStatus_t;
798/// cuDNN: set activation descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
799pub type PFN_cudnnSetActivationDescriptor = unsafe extern "C" fn(
800    desc: cudnnActivationDescriptor_t,
801    mode: cudnnActivationMode_t,
802    nan_prop: cudnnNanPropagation_t,
803    coef: c_double,
804) -> cudnnStatus_t;
805
806/// cuDNN: activation forward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
807pub type PFN_cudnnActivationForward = unsafe extern "C" fn(
808    handle: cudnnHandle_t,
809    activation_desc: cudnnActivationDescriptor_t,
810    alpha: *const c_void,
811    x_desc: cudnnTensorDescriptor_t,
812    x: *const c_void,
813    beta: *const c_void,
814    y_desc: cudnnTensorDescriptor_t,
815    y: *mut c_void,
816) -> cudnnStatus_t;
817
// ---- convolution ----------------------------------------------------------
819
820/// cuDNN: create filter descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
821pub type PFN_cudnnCreateFilterDescriptor =
822    unsafe extern "C" fn(desc: *mut cudnnFilterDescriptor_t) -> cudnnStatus_t;
823/// cuDNN: destroy filter descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
824pub type PFN_cudnnDestroyFilterDescriptor =
825    unsafe extern "C" fn(desc: cudnnFilterDescriptor_t) -> cudnnStatus_t;
826/// cuDNN: set filter4d descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
827pub type PFN_cudnnSetFilter4dDescriptor = unsafe extern "C" fn(
828    desc: cudnnFilterDescriptor_t,
829    data_type: cudnnDataType_t,
830    format: cudnnTensorFormat_t,
831    k: c_int,
832    c: c_int,
833    h: c_int,
834    w: c_int,
835) -> cudnnStatus_t;
836
837/// cuDNN: create convolution descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
838pub type PFN_cudnnCreateConvolutionDescriptor =
839    unsafe extern "C" fn(desc: *mut cudnnConvolutionDescriptor_t) -> cudnnStatus_t;
840/// cuDNN: destroy convolution descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
841pub type PFN_cudnnDestroyConvolutionDescriptor =
842    unsafe extern "C" fn(desc: cudnnConvolutionDescriptor_t) -> cudnnStatus_t;
843/// cuDNN: set convolution2d descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
844pub type PFN_cudnnSetConvolution2dDescriptor = unsafe extern "C" fn(
845    desc: cudnnConvolutionDescriptor_t,
846    pad_h: c_int,
847    pad_w: c_int,
848    u: c_int,
849    v: c_int,
850    dilation_h: c_int,
851    dilation_w: c_int,
852    mode: cudnnConvolutionMode_t,
853    compute_type: cudnnDataType_t,
854) -> cudnnStatus_t;
855
856/// cuDNN: get convolution2d forward output dim. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
857pub type PFN_cudnnGetConvolution2dForwardOutputDim = unsafe extern "C" fn(
858    conv_desc: cudnnConvolutionDescriptor_t,
859    input_desc: cudnnTensorDescriptor_t,
860    filter_desc: cudnnFilterDescriptor_t,
861    n: *mut c_int,
862    c: *mut c_int,
863    h: *mut c_int,
864    w: *mut c_int,
865) -> cudnnStatus_t;
866
867/// cuDNN: get convolution forward workspace size. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
868pub type PFN_cudnnGetConvolutionForwardWorkspaceSize = unsafe extern "C" fn(
869    handle: cudnnHandle_t,
870    x_desc: cudnnTensorDescriptor_t,
871    w_desc: cudnnFilterDescriptor_t,
872    conv_desc: cudnnConvolutionDescriptor_t,
873    y_desc: cudnnTensorDescriptor_t,
874    algo: cudnnConvolutionFwdAlgo_t,
875    size_in_bytes: *mut usize,
876) -> cudnnStatus_t;
877
878/// cuDNN: convolution forward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
879pub type PFN_cudnnConvolutionForward = unsafe extern "C" fn(
880    handle: cudnnHandle_t,
881    alpha: *const c_void,
882    x_desc: cudnnTensorDescriptor_t,
883    x: *const c_void,
884    w_desc: cudnnFilterDescriptor_t,
885    w: *const c_void,
886    conv_desc: cudnnConvolutionDescriptor_t,
887    algo: cudnnConvolutionFwdAlgo_t,
888    workspace: *mut c_void,
889    workspace_size: usize,
890    beta: *const c_void,
891    y_desc: cudnnTensorDescriptor_t,
892    y: *mut c_void,
893) -> cudnnStatus_t;
894
895/// cuDNN: convolution backward data. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
896pub type PFN_cudnnConvolutionBackwardData = unsafe extern "C" fn(
897    handle: cudnnHandle_t,
898    alpha: *const c_void,
899    w_desc: cudnnFilterDescriptor_t,
900    w: *const c_void,
901    dy_desc: cudnnTensorDescriptor_t,
902    dy: *const c_void,
903    conv_desc: cudnnConvolutionDescriptor_t,
904    algo: cudnnConvolutionBwdDataAlgo_t,
905    workspace: *mut c_void,
906    workspace_size: usize,
907    beta: *const c_void,
908    dx_desc: cudnnTensorDescriptor_t,
909    dx: *mut c_void,
910) -> cudnnStatus_t;
911
912/// cuDNN: convolution backward filter. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
913pub type PFN_cudnnConvolutionBackwardFilter = unsafe extern "C" fn(
914    handle: cudnnHandle_t,
915    alpha: *const c_void,
916    x_desc: cudnnTensorDescriptor_t,
917    x: *const c_void,
918    dy_desc: cudnnTensorDescriptor_t,
919    dy: *const c_void,
920    conv_desc: cudnnConvolutionDescriptor_t,
921    algo: cudnnConvolutionBwdFilterAlgo_t,
922    workspace: *mut c_void,
923    workspace_size: usize,
924    beta: *const c_void,
925    dw_desc: cudnnFilterDescriptor_t,
926    dw: *mut c_void,
927) -> cudnnStatus_t;
928
929/// cuDNN: convolution backward bias. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
930pub type PFN_cudnnConvolutionBackwardBias = unsafe extern "C" fn(
931    handle: cudnnHandle_t,
932    alpha: *const c_void,
933    dy_desc: cudnnTensorDescriptor_t,
934    dy: *const c_void,
935    beta: *const c_void,
936    db_desc: cudnnTensorDescriptor_t,
937    db: *mut c_void,
938) -> cudnnStatus_t;
939
940/// cuDNN: get convolution backward data workspace size. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
941pub type PFN_cudnnGetConvolutionBackwardDataWorkspaceSize = unsafe extern "C" fn(
942    handle: cudnnHandle_t,
943    w_desc: cudnnFilterDescriptor_t,
944    dy_desc: cudnnTensorDescriptor_t,
945    conv_desc: cudnnConvolutionDescriptor_t,
946    dx_desc: cudnnTensorDescriptor_t,
947    algo: cudnnConvolutionBwdDataAlgo_t,
948    size_in_bytes: *mut usize,
949) -> cudnnStatus_t;
950
951/// cuDNN: get convolution backward filter workspace size. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
952pub type PFN_cudnnGetConvolutionBackwardFilterWorkspaceSize = unsafe extern "C" fn(
953    handle: cudnnHandle_t,
954    x_desc: cudnnTensorDescriptor_t,
955    dy_desc: cudnnTensorDescriptor_t,
956    conv_desc: cudnnConvolutionDescriptor_t,
957    dw_desc: cudnnFilterDescriptor_t,
958    algo: cudnnConvolutionBwdFilterAlgo_t,
959    size_in_bytes: *mut usize,
960) -> cudnnStatus_t;
961
// ---- pooling --------------------------------------------------------------
963
964/// cuDNN: create pooling descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
965pub type PFN_cudnnCreatePoolingDescriptor =
966    unsafe extern "C" fn(desc: *mut cudnnPoolingDescriptor_t) -> cudnnStatus_t;
967/// cuDNN: destroy pooling descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
968pub type PFN_cudnnDestroyPoolingDescriptor =
969    unsafe extern "C" fn(desc: cudnnPoolingDescriptor_t) -> cudnnStatus_t;
970/// cuDNN: set pooling2d descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
971pub type PFN_cudnnSetPooling2dDescriptor = unsafe extern "C" fn(
972    desc: cudnnPoolingDescriptor_t,
973    mode: cudnnPoolingMode_t,
974    nan_prop: cudnnNanPropagation_t,
975    window_h: c_int,
976    window_w: c_int,
977    vertical_padding: c_int,
978    horizontal_padding: c_int,
979    vertical_stride: c_int,
980    horizontal_stride: c_int,
981) -> cudnnStatus_t;
982/// cuDNN: pooling forward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
983pub type PFN_cudnnPoolingForward = unsafe extern "C" fn(
984    handle: cudnnHandle_t,
985    pool_desc: cudnnPoolingDescriptor_t,
986    alpha: *const c_void,
987    x_desc: cudnnTensorDescriptor_t,
988    x: *const c_void,
989    beta: *const c_void,
990    y_desc: cudnnTensorDescriptor_t,
991    y: *mut c_void,
992) -> cudnnStatus_t;
993/// cuDNN: pooling backward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
994pub type PFN_cudnnPoolingBackward = unsafe extern "C" fn(
995    handle: cudnnHandle_t,
996    pool_desc: cudnnPoolingDescriptor_t,
997    alpha: *const c_void,
998    y_desc: cudnnTensorDescriptor_t,
999    y: *const c_void,
1000    dy_desc: cudnnTensorDescriptor_t,
1001    dy: *const c_void,
1002    x_desc: cudnnTensorDescriptor_t,
1003    x: *const c_void,
1004    beta: *const c_void,
1005    dx_desc: cudnnTensorDescriptor_t,
1006    dx: *mut c_void,
1007) -> cudnnStatus_t;
1008
1009// ---- softmax --------------------------------------------------------------
1010
1011/// cuDNN: softmax forward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1012pub type PFN_cudnnSoftmaxForward = unsafe extern "C" fn(
1013    handle: cudnnHandle_t,
1014    algo: cudnnSoftmaxAlgorithm_t,
1015    mode: cudnnSoftmaxMode_t,
1016    alpha: *const c_void,
1017    x_desc: cudnnTensorDescriptor_t,
1018    x: *const c_void,
1019    beta: *const c_void,
1020    y_desc: cudnnTensorDescriptor_t,
1021    y: *mut c_void,
1022) -> cudnnStatus_t;
1023
1024/// cuDNN: softmax backward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1025pub type PFN_cudnnSoftmaxBackward = unsafe extern "C" fn(
1026    handle: cudnnHandle_t,
1027    algo: cudnnSoftmaxAlgorithm_t,
1028    mode: cudnnSoftmaxMode_t,
1029    alpha: *const c_void,
1030    y_desc: cudnnTensorDescriptor_t,
1031    y: *const c_void,
1032    dy_desc: cudnnTensorDescriptor_t,
1033    dy: *const c_void,
1034    beta: *const c_void,
1035    dx_desc: cudnnTensorDescriptor_t,
1036    dx: *mut c_void,
1037) -> cudnnStatus_t;
1038
1039// ---- batch normalization --------------------------------------------------
1040
1041/// cuDNN: batch normalization forward inference. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1042pub type PFN_cudnnBatchNormalizationForwardInference = unsafe extern "C" fn(
1043    handle: cudnnHandle_t,
1044    mode: cudnnBatchNormMode_t,
1045    alpha: *const c_void,
1046    beta: *const c_void,
1047    x_desc: cudnnTensorDescriptor_t,
1048    x: *const c_void,
1049    y_desc: cudnnTensorDescriptor_t,
1050    y: *mut c_void,
1051    bn_scale_bias_mean_var_desc: cudnnTensorDescriptor_t,
1052    bn_scale: *const c_void,
1053    bn_bias: *const c_void,
1054    estimated_mean: *const c_void,
1055    estimated_variance: *const c_void,
1056    epsilon: c_double,
1057) -> cudnnStatus_t;
1058
1059/// cuDNN: batch normalization forward training. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1060pub type PFN_cudnnBatchNormalizationForwardTraining = unsafe extern "C" fn(
1061    handle: cudnnHandle_t,
1062    mode: cudnnBatchNormMode_t,
1063    alpha: *const c_void,
1064    beta: *const c_void,
1065    x_desc: cudnnTensorDescriptor_t,
1066    x: *const c_void,
1067    y_desc: cudnnTensorDescriptor_t,
1068    y: *mut c_void,
1069    bn_scale_bias_mean_var_desc: cudnnTensorDescriptor_t,
1070    bn_scale: *const c_void,
1071    bn_bias: *const c_void,
1072    exponential_average_factor: c_double,
1073    result_running_mean: *mut c_void,
1074    result_running_variance: *mut c_void,
1075    epsilon: c_double,
1076    result_save_mean: *mut c_void,
1077    result_save_inv_variance: *mut c_void,
1078) -> cudnnStatus_t;
1079
1080/// cuDNN: batch normalization backward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1081pub type PFN_cudnnBatchNormalizationBackward = unsafe extern "C" fn(
1082    handle: cudnnHandle_t,
1083    mode: cudnnBatchNormMode_t,
1084    alpha_data_diff: *const c_void,
1085    beta_data_diff: *const c_void,
1086    alpha_param_diff: *const c_void,
1087    beta_param_diff: *const c_void,
1088    x_desc: cudnnTensorDescriptor_t,
1089    x: *const c_void,
1090    dy_desc: cudnnTensorDescriptor_t,
1091    dy: *const c_void,
1092    dx_desc: cudnnTensorDescriptor_t,
1093    dx: *mut c_void,
1094    bn_scale_bias_diff_desc: cudnnTensorDescriptor_t,
1095    bn_scale: *const c_void,
1096    bn_scale_result: *mut c_void,
1097    bn_bias_result: *mut c_void,
1098    epsilon: c_double,
1099    saved_mean: *const c_void,
1100    saved_inv_variance: *const c_void,
1101) -> cudnnStatus_t;
1102
1103// ---- op-tensor / reduce / transform --------------------------------------
1104
1105/// cuDNN: create op tensor descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1106pub type PFN_cudnnCreateOpTensorDescriptor =
1107    unsafe extern "C" fn(desc: *mut cudnnOpTensorDescriptor_t) -> cudnnStatus_t;
1108/// cuDNN: destroy op tensor descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1109pub type PFN_cudnnDestroyOpTensorDescriptor =
1110    unsafe extern "C" fn(desc: cudnnOpTensorDescriptor_t) -> cudnnStatus_t;
1111/// cuDNN: set op tensor descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1112pub type PFN_cudnnSetOpTensorDescriptor = unsafe extern "C" fn(
1113    desc: cudnnOpTensorDescriptor_t,
1114    op: cudnnOpTensorOp_t,
1115    compute_type: cudnnDataType_t,
1116    nan_prop: cudnnNanPropagation_t,
1117) -> cudnnStatus_t;
1118/// cuDNN: op tensor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1119pub type PFN_cudnnOpTensor = unsafe extern "C" fn(
1120    handle: cudnnHandle_t,
1121    desc: cudnnOpTensorDescriptor_t,
1122    alpha1: *const c_void,
1123    a_desc: cudnnTensorDescriptor_t,
1124    a: *const c_void,
1125    alpha2: *const c_void,
1126    b_desc: cudnnTensorDescriptor_t,
1127    b: *const c_void,
1128    beta: *const c_void,
1129    c_desc: cudnnTensorDescriptor_t,
1130    c: *mut c_void,
1131) -> cudnnStatus_t;
1132
1133/// cuDNN: create reduce tensor descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1134pub type PFN_cudnnCreateReduceTensorDescriptor =
1135    unsafe extern "C" fn(desc: *mut cudnnReduceTensorDescriptor_t) -> cudnnStatus_t;
1136/// cuDNN: destroy reduce tensor descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1137pub type PFN_cudnnDestroyReduceTensorDescriptor =
1138    unsafe extern "C" fn(desc: cudnnReduceTensorDescriptor_t) -> cudnnStatus_t;
1139/// cuDNN: set reduce tensor descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1140pub type PFN_cudnnSetReduceTensorDescriptor = unsafe extern "C" fn(
1141    desc: cudnnReduceTensorDescriptor_t,
1142    op: cudnnReduceTensorOp_t,
1143    compute_type: cudnnDataType_t,
1144    nan_prop: cudnnNanPropagation_t,
1145    indices: cudnnReduceTensorIndices_t,
1146    indices_type: cudnnIndicesType_t,
1147) -> cudnnStatus_t;
1148/// cuDNN: get reduction workspace size. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1149pub type PFN_cudnnGetReductionWorkspaceSize = unsafe extern "C" fn(
1150    handle: cudnnHandle_t,
1151    desc: cudnnReduceTensorDescriptor_t,
1152    a_desc: cudnnTensorDescriptor_t,
1153    c_desc: cudnnTensorDescriptor_t,
1154    workspace_size: *mut usize,
1155) -> cudnnStatus_t;
1156/// cuDNN: reduce tensor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1157pub type PFN_cudnnReduceTensor = unsafe extern "C" fn(
1158    handle: cudnnHandle_t,
1159    desc: cudnnReduceTensorDescriptor_t,
1160    indices: *mut c_void,
1161    indices_size: usize,
1162    workspace: *mut c_void,
1163    workspace_size: usize,
1164    alpha: *const c_void,
1165    a_desc: cudnnTensorDescriptor_t,
1166    a: *const c_void,
1167    beta: *const c_void,
1168    c_desc: cudnnTensorDescriptor_t,
1169    c: *mut c_void,
1170) -> cudnnStatus_t;
1171
1172/// cuDNN: add tensor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1173pub type PFN_cudnnAddTensor = unsafe extern "C" fn(
1174    handle: cudnnHandle_t,
1175    alpha: *const c_void,
1176    a_desc: cudnnTensorDescriptor_t,
1177    a: *const c_void,
1178    beta: *const c_void,
1179    c_desc: cudnnTensorDescriptor_t,
1180    c: *mut c_void,
1181) -> cudnnStatus_t;
1182
1183/// cuDNN: transform tensor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1184pub type PFN_cudnnTransformTensor = unsafe extern "C" fn(
1185    handle: cudnnHandle_t,
1186    alpha: *const c_void,
1187    x_desc: cudnnTensorDescriptor_t,
1188    x: *const c_void,
1189    beta: *const c_void,
1190    y_desc: cudnnTensorDescriptor_t,
1191    y: *mut c_void,
1192) -> cudnnStatus_t;
1193
1194/// cuDNN: scale tensor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1195pub type PFN_cudnnScaleTensor = unsafe extern "C" fn(
1196    handle: cudnnHandle_t,
1197    y_desc: cudnnTensorDescriptor_t,
1198    y: *mut c_void,
1199    alpha: *const c_void,
1200) -> cudnnStatus_t;
1201
1202/// cuDNN: set tensor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1203pub type PFN_cudnnSetTensor = unsafe extern "C" fn(
1204    handle: cudnnHandle_t,
1205    y_desc: cudnnTensorDescriptor_t,
1206    y: *mut c_void,
1207    value_ptr: *const c_void,
1208) -> cudnnStatus_t;
1209
1210// ---- LRN ------------------------------------------------------------------
1211
1212/// cuDNN: create LRN descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1213pub type PFN_cudnnCreateLRNDescriptor =
1214    unsafe extern "C" fn(desc: *mut cudnnLRNDescriptor_t) -> cudnnStatus_t;
1215/// cuDNN: destroy LRN descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1216pub type PFN_cudnnDestroyLRNDescriptor =
1217    unsafe extern "C" fn(desc: cudnnLRNDescriptor_t) -> cudnnStatus_t;
1218/// cuDNN: set LRN descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1219pub type PFN_cudnnSetLRNDescriptor = unsafe extern "C" fn(
1220    desc: cudnnLRNDescriptor_t,
1221    lrn_n: c_int,
1222    lrn_alpha: c_double,
1223    lrn_beta: c_double,
1224    lrn_k: c_double,
1225) -> cudnnStatus_t;
1226/// cuDNN: LRN cross channel forward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1227pub type PFN_cudnnLRNCrossChannelForward = unsafe extern "C" fn(
1228    handle: cudnnHandle_t,
1229    lrn_desc: cudnnLRNDescriptor_t,
1230    lrn_mode: c_int,
1231    alpha: *const c_void,
1232    x_desc: cudnnTensorDescriptor_t,
1233    x: *const c_void,
1234    beta: *const c_void,
1235    y_desc: cudnnTensorDescriptor_t,
1236    y: *mut c_void,
1237) -> cudnnStatus_t;
1238
1239// ---- dropout --------------------------------------------------------------
1240
1241/// cuDNN: create dropout descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1242pub type PFN_cudnnCreateDropoutDescriptor =
1243    unsafe extern "C" fn(desc: *mut cudnnDropoutDescriptor_t) -> cudnnStatus_t;
1244/// cuDNN: destroy dropout descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1245pub type PFN_cudnnDestroyDropoutDescriptor =
1246    unsafe extern "C" fn(desc: cudnnDropoutDescriptor_t) -> cudnnStatus_t;
1247/// cuDNN: dropout get states size. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1248pub type PFN_cudnnDropoutGetStatesSize =
1249    unsafe extern "C" fn(handle: cudnnHandle_t, size_in_bytes: *mut usize) -> cudnnStatus_t;
1250/// cuDNN: dropout get reserve space size. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1251pub type PFN_cudnnDropoutGetReserveSpaceSize = unsafe extern "C" fn(
1252    x_desc: cudnnTensorDescriptor_t,
1253    size_in_bytes: *mut usize,
1254) -> cudnnStatus_t;
1255/// cuDNN: set dropout descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1256pub type PFN_cudnnSetDropoutDescriptor = unsafe extern "C" fn(
1257    desc: cudnnDropoutDescriptor_t,
1258    handle: cudnnHandle_t,
1259    dropout: f32,
1260    states: *mut c_void,
1261    state_size: usize,
1262    seed: u64,
1263) -> cudnnStatus_t;
1264/// cuDNN: dropout forward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1265pub type PFN_cudnnDropoutForward = unsafe extern "C" fn(
1266    handle: cudnnHandle_t,
1267    desc: cudnnDropoutDescriptor_t,
1268    x_desc: cudnnTensorDescriptor_t,
1269    x: *const c_void,
1270    y_desc: cudnnTensorDescriptor_t,
1271    y: *mut c_void,
1272    reserve_space: *mut c_void,
1273    reserve_space_size: usize,
1274) -> cudnnStatus_t;
1275/// cuDNN: dropout backward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1276pub type PFN_cudnnDropoutBackward = unsafe extern "C" fn(
1277    handle: cudnnHandle_t,
1278    desc: cudnnDropoutDescriptor_t,
1279    dy_desc: cudnnTensorDescriptor_t,
1280    dy: *const c_void,
1281    dx_desc: cudnnTensorDescriptor_t,
1282    dx: *mut c_void,
1283    reserve_space: *mut c_void,
1284    reserve_space_size: usize,
1285) -> cudnnStatus_t;
1286
1287// ---- RNN ------------------------------------------------------------------
1288
1289/// cuDNN: create RNN descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1290pub type PFN_cudnnCreateRNNDescriptor =
1291    unsafe extern "C" fn(desc: *mut cudnnRNNDescriptor_t) -> cudnnStatus_t;
1292/// cuDNN: destroy RNN descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1293pub type PFN_cudnnDestroyRNNDescriptor =
1294    unsafe extern "C" fn(desc: cudnnRNNDescriptor_t) -> cudnnStatus_t;
1295
1296/// cuDNN: create RNN data descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1297pub type PFN_cudnnCreateRNNDataDescriptor =
1298    unsafe extern "C" fn(desc: *mut cudnnRNNDataDescriptor_t) -> cudnnStatus_t;
1299/// cuDNN: destroy RNN data descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1300pub type PFN_cudnnDestroyRNNDataDescriptor =
1301    unsafe extern "C" fn(desc: cudnnRNNDataDescriptor_t) -> cudnnStatus_t;
1302
1303/// cuDNN: RNN forward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1304pub type PFN_cudnnRNNForward = unsafe extern "C" fn(
1305    handle: cudnnHandle_t,
1306    rnn_desc: cudnnRNNDescriptor_t,
1307    fwd_mode: c_int,
1308    dev_seq_lengths: *const i32,
1309    x_desc: cudnnRNNDataDescriptor_t,
1310    x: *const c_void,
1311    y_desc: cudnnRNNDataDescriptor_t,
1312    y: *mut c_void,
1313    h_desc: cudnnTensorDescriptor_t,
1314    hx: *const c_void,
1315    hy: *mut c_void,
1316    c_desc: cudnnTensorDescriptor_t,
1317    cx: *const c_void,
1318    cy: *mut c_void,
1319    weight_space_size: usize,
1320    weight_space: *const c_void,
1321    work_space_size: usize,
1322    work_space: *mut c_void,
1323    reserve_space_size: usize,
1324    reserve_space: *mut c_void,
1325) -> cudnnStatus_t;
1326
1327// ---- cuDNN backend (Graph) API -------------------------------------------
1328
1329/// cuDNN: backend create descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1330pub type PFN_cudnnBackendCreateDescriptor = unsafe extern "C" fn(
1331    descriptor_type: cudnnBackendDescriptorType_t,
1332    descriptor: *mut cudnnBackendDescriptor_t,
1333) -> cudnnStatus_t;
1334/// cuDNN: backend destroy descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1335pub type PFN_cudnnBackendDestroyDescriptor =
1336    unsafe extern "C" fn(descriptor: cudnnBackendDescriptor_t) -> cudnnStatus_t;
1337/// cuDNN: backend initialize. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1338pub type PFN_cudnnBackendInitialize =
1339    unsafe extern "C" fn(descriptor: cudnnBackendDescriptor_t) -> cudnnStatus_t;
1340/// cuDNN: backend finalize. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1341pub type PFN_cudnnBackendFinalize =
1342    unsafe extern "C" fn(descriptor: cudnnBackendDescriptor_t) -> cudnnStatus_t;
1343/// cuDNN: backend set attribute. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1344pub type PFN_cudnnBackendSetAttribute = unsafe extern "C" fn(
1345    descriptor: cudnnBackendDescriptor_t,
1346    attribute_name: cudnnBackendAttributeName_t,
1347    attribute_type: cudnnBackendAttributeType_t,
1348    element_count: i64,
1349    array_of_elements: *const c_void,
1350) -> cudnnStatus_t;
1351/// cuDNN: backend get attribute. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1352pub type PFN_cudnnBackendGetAttribute = unsafe extern "C" fn(
1353    descriptor: cudnnBackendDescriptor_t,
1354    attribute_name: cudnnBackendAttributeName_t,
1355    attribute_type: cudnnBackendAttributeType_t,
1356    requested_element_count: i64,
1357    element_count: *mut i64,
1358    array_of_elements: *mut c_void,
1359) -> cudnnStatus_t;
1360/// cuDNN: backend execute. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1361pub type PFN_cudnnBackendExecute = unsafe extern "C" fn(
1362    handle: cudnnHandle_t,
1363    execution_plan: cudnnBackendDescriptor_t,
1364    variant_pack: cudnnBackendDescriptor_t,
1365) -> cudnnStatus_t;
1366
1367// ---- error-string helper -------------------------------------------------
1368
1369/// cuDNN: get error string. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1370pub type PFN_cudnnGetErrorString =
1371    unsafe extern "C" fn(status: cudnnStatus_t) -> *const core::ffi::c_char;
1372
1373// ---- N-dimensional tensor / filter descriptors --------------------------
1374
1375/// cuDNN: set tensor nd descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1376pub type PFN_cudnnSetTensorNdDescriptor = unsafe extern "C" fn(
1377    desc: cudnnTensorDescriptor_t,
1378    data_type: cudnnDataType_t,
1379    nb_dims: c_int,
1380    dim_a: *const c_int,
1381    stride_a: *const c_int,
1382) -> cudnnStatus_t;
1383
1384/// cuDNN: get tensor nd descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1385pub type PFN_cudnnGetTensorNdDescriptor = unsafe extern "C" fn(
1386    desc: cudnnTensorDescriptor_t,
1387    nb_dims_requested: c_int,
1388    data_type: *mut cudnnDataType_t,
1389    nb_dims: *mut c_int,
1390    dim_a: *mut c_int,
1391    stride_a: *mut c_int,
1392) -> cudnnStatus_t;
1393
1394/// cuDNN: set filter nd descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1395pub type PFN_cudnnSetFilterNdDescriptor = unsafe extern "C" fn(
1396    desc: cudnnFilterDescriptor_t,
1397    data_type: cudnnDataType_t,
1398    format: cudnnTensorFormat_t,
1399    nb_dims: c_int,
1400    filter_dim_a: *const c_int,
1401) -> cudnnStatus_t;
1402
1403/// cuDNN: set convolution nd descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1404pub type PFN_cudnnSetConvolutionNdDescriptor = unsafe extern "C" fn(
1405    desc: cudnnConvolutionDescriptor_t,
1406    array_length: c_int,
1407    pad_a: *const c_int,
1408    filter_stride_a: *const c_int,
1409    dilation_a: *const c_int,
1410    mode: cudnnConvolutionMode_t,
1411    compute_type: cudnnDataType_t,
1412) -> cudnnStatus_t;
1413
1414/// cuDNN: set pooling nd descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1415pub type PFN_cudnnSetPoolingNdDescriptor = unsafe extern "C" fn(
1416    desc: cudnnPoolingDescriptor_t,
1417    mode: cudnnPoolingMode_t,
1418    nan_prop: cudnnNanPropagation_t,
1419    nb_dims: c_int,
1420    window_dim_a: *const c_int,
1421    padding_a: *const c_int,
1422    stride_a: *const c_int,
1423) -> cudnnStatus_t;
1424
1425// ---- CTC loss ------------------------------------------------------------
1426
1427/// cuDNN: create CTC loss descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1428pub type PFN_cudnnCreateCTCLossDescriptor =
1429    unsafe extern "C" fn(desc: *mut cudnnCTCLossDescriptor_t) -> cudnnStatus_t;
1430/// cuDNN: destroy CTC loss descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1431pub type PFN_cudnnDestroyCTCLossDescriptor =
1432    unsafe extern "C" fn(desc: cudnnCTCLossDescriptor_t) -> cudnnStatus_t;
1433/// cuDNN: set CTC loss descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1434pub type PFN_cudnnSetCTCLossDescriptor = unsafe extern "C" fn(
1435    desc: cudnnCTCLossDescriptor_t,
1436    compute_type: cudnnDataType_t,
1437) -> cudnnStatus_t;
1438
1439/// cuDNN: get CTC loss workspace size. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1440pub type PFN_cudnnGetCTCLossWorkspaceSize = unsafe extern "C" fn(
1441    handle: cudnnHandle_t,
1442    probs_desc: cudnnTensorDescriptor_t,
1443    gradients_desc: cudnnTensorDescriptor_t,
1444    labels: *const c_int,
1445    label_lengths: *const c_int,
1446    input_lengths: *const c_int,
1447    ctc_loss_algo: c_int,
1448    ctc_loss_desc: cudnnCTCLossDescriptor_t,
1449    size_in_bytes: *mut usize,
1450) -> cudnnStatus_t;
1451
1452/// cuDNN: CTC loss. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1453pub type PFN_cudnnCTCLoss = unsafe extern "C" fn(
1454    handle: cudnnHandle_t,
1455    probs_desc: cudnnTensorDescriptor_t,
1456    probs: *const c_void,
1457    labels: *const c_int,
1458    label_lengths: *const c_int,
1459    input_lengths: *const c_int,
1460    costs: *mut c_void,
1461    gradients_desc: cudnnTensorDescriptor_t,
1462    gradients: *mut c_void,
1463    ctc_loss_algo: c_int,
1464    ctc_loss_desc: cudnnCTCLossDescriptor_t,
1465    workspace: *mut c_void,
1466    workspace_size: usize,
1467) -> cudnnStatus_t;
1468
1469// ---- RNN backward --------------------------------------------------------
1470
1471/// cuDNN: RNN backward data. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1472pub type PFN_cudnnRNNBackwardData_v8 = unsafe extern "C" fn(
1473    handle: cudnnHandle_t,
1474    rnn_desc: cudnnRNNDescriptor_t,
1475    dev_seq_lengths: *const i32,
1476    y_desc: cudnnRNNDataDescriptor_t,
1477    y: *const c_void,
1478    dy: *const c_void,
1479    x_desc: cudnnRNNDataDescriptor_t,
1480    dx: *mut c_void,
1481    h_desc: cudnnTensorDescriptor_t,
1482    hx: *const c_void,
1483    dhy: *const c_void,
1484    dhx: *mut c_void,
1485    c_desc: cudnnTensorDescriptor_t,
1486    cx: *const c_void,
1487    dcy: *const c_void,
1488    dcx: *mut c_void,
1489    weight_space_size: usize,
1490    weight_space: *const c_void,
1491    work_space_size: usize,
1492    work_space: *mut c_void,
1493    reserve_space_size: usize,
1494    reserve_space: *mut c_void,
1495) -> cudnnStatus_t;
1496
1497/// cuDNN: RNN backward weights. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1498pub type PFN_cudnnRNNBackwardWeights_v8 = unsafe extern "C" fn(
1499    handle: cudnnHandle_t,
1500    rnn_desc: cudnnRNNDescriptor_t,
1501    add_grad: c_int,
1502    dev_seq_lengths: *const i32,
1503    x_desc: cudnnRNNDataDescriptor_t,
1504    x: *const c_void,
1505    h_desc: cudnnTensorDescriptor_t,
1506    hx: *const c_void,
1507    y_desc: cudnnRNNDataDescriptor_t,
1508    y: *const c_void,
1509    weight_space_size: usize,
1510    dweight_space: *mut c_void,
1511    work_space_size: usize,
1512    work_space: *mut c_void,
1513    reserve_space_size: usize,
1514    reserve_space: *mut c_void,
1515) -> cudnnStatus_t;
1516
1517// ---- Spatial transformer --------------------------------------------------
1518
/// Opaque descriptor handle, represented as an untyped mutable pointer.
/// Mirrors `cudnnSpatialTransformerDescriptor_t`.
pub type cudnnSpatialTransformerDescriptor_t = *mut c_void;
1521
1522/// cuDNN: create spatial transformer descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1523pub type PFN_cudnnCreateSpatialTransformerDescriptor =
1524    unsafe extern "C" fn(desc: *mut cudnnSpatialTransformerDescriptor_t) -> cudnnStatus_t;
1525/// cuDNN: destroy spatial transformer descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1526pub type PFN_cudnnDestroySpatialTransformerDescriptor =
1527    unsafe extern "C" fn(desc: cudnnSpatialTransformerDescriptor_t) -> cudnnStatus_t;
1528
1529/// cuDNN: set spatial transformer nd descriptor. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1530pub type PFN_cudnnSetSpatialTransformerNdDescriptor = unsafe extern "C" fn(
1531    desc: cudnnSpatialTransformerDescriptor_t,
1532    sampler_type: c_int,
1533    data_type: cudnnDataType_t,
1534    nb_dims: c_int,
1535    dim_a: *const c_int,
1536) -> cudnnStatus_t;
1537
1538/// cuDNN: spatial tf grid generator forward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1539pub type PFN_cudnnSpatialTfGridGeneratorForward = unsafe extern "C" fn(
1540    handle: cudnnHandle_t,
1541    st_desc: cudnnSpatialTransformerDescriptor_t,
1542    theta: *const c_void,
1543    grid: *mut c_void,
1544) -> cudnnStatus_t;
1545
1546/// cuDNN: spatial tf sampler forward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1547pub type PFN_cudnnSpatialTfSamplerForward = unsafe extern "C" fn(
1548    handle: cudnnHandle_t,
1549    st_desc: cudnnSpatialTransformerDescriptor_t,
1550    alpha: *const c_void,
1551    x_desc: cudnnTensorDescriptor_t,
1552    x: *const c_void,
1553    grid: *const c_void,
1554    beta: *const c_void,
1555    y_desc: cudnnTensorDescriptor_t,
1556    y: *mut c_void,
1557) -> cudnnStatus_t;
1558
1559// ==========================================================================
1560// Tier 1 — convolution + activation + reduction misc gaps
1561// ==========================================================================
1562
1563/// cuDNN: set convolution group count. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1564pub type PFN_cudnnSetConvolutionGroupCount = unsafe extern "C" fn(
1565    desc: cudnnConvolutionDescriptor_t,
1566    group_count: c_int,
1567) -> cudnnStatus_t;
1568/// cuDNN: get convolution group count. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1569pub type PFN_cudnnGetConvolutionGroupCount = unsafe extern "C" fn(
1570    desc: cudnnConvolutionDescriptor_t,
1571    group_count: *mut c_int,
1572) -> cudnnStatus_t;
1573
1574/// cuDNN: set convolution math type. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1575pub type PFN_cudnnSetConvolutionMathType = unsafe extern "C" fn(
1576    desc: cudnnConvolutionDescriptor_t,
1577    math_type: cudnnMathType_t,
1578) -> cudnnStatus_t;
1579/// cuDNN: get convolution math type. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1580pub type PFN_cudnnGetConvolutionMathType = unsafe extern "C" fn(
1581    desc: cudnnConvolutionDescriptor_t,
1582    math_type: *mut cudnnMathType_t,
1583) -> cudnnStatus_t;
1584
1585/// cuDNN: set convolution reorder type. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1586pub type PFN_cudnnSetConvolutionReorderType = unsafe extern "C" fn(
1587    desc: cudnnConvolutionDescriptor_t,
1588    reorder_type: cudnnReorderType_t,
1589) -> cudnnStatus_t;
1590/// cuDNN: get convolution reorder type. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1591pub type PFN_cudnnGetConvolutionReorderType = unsafe extern "C" fn(
1592    desc: cudnnConvolutionDescriptor_t,
1593    reorder_type: *mut cudnnReorderType_t,
1594) -> cudnnStatus_t;
1595
1596/// cuDNN: reorder filter and bias. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1597pub type PFN_cudnnReorderFilterAndBias = unsafe extern "C" fn(
1598    handle: cudnnHandle_t,
1599    filter_desc: cudnnFilterDescriptor_t,
1600    reorder_type: cudnnReorderType_t,
1601    filter_data: *const c_void,
1602    reordered_filter_data: *mut c_void,
1603    reorder_bias: c_int,
1604    bias_data: *const c_void,
1605    reordered_bias_data: *mut c_void,
1606) -> cudnnStatus_t;
1607
1608/// cuDNN: convolution bias activation forward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1609pub type PFN_cudnnConvolutionBiasActivationForward = unsafe extern "C" fn(
1610    handle: cudnnHandle_t,
1611    alpha1: *const c_void,
1612    x_desc: cudnnTensorDescriptor_t,
1613    x: *const c_void,
1614    w_desc: cudnnFilterDescriptor_t,
1615    w: *const c_void,
1616    conv_desc: cudnnConvolutionDescriptor_t,
1617    algo: cudnnConvolutionFwdAlgo_t,
1618    workspace: *mut c_void,
1619    workspace_size: usize,
1620    alpha2: *const c_void,
1621    z_desc: cudnnTensorDescriptor_t,
1622    z: *const c_void,
1623    bias_desc: cudnnTensorDescriptor_t,
1624    bias: *const c_void,
1625    activation_desc: cudnnActivationDescriptor_t,
1626    y_desc: cudnnTensorDescriptor_t,
1627    y: *mut c_void,
1628) -> cudnnStatus_t;
1629
1630/// cuDNN: activation backward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1631pub type PFN_cudnnActivationBackward = unsafe extern "C" fn(
1632    handle: cudnnHandle_t,
1633    activation_desc: cudnnActivationDescriptor_t,
1634    alpha: *const c_void,
1635    y_desc: cudnnTensorDescriptor_t,
1636    y: *const c_void,
1637    dy_desc: cudnnTensorDescriptor_t,
1638    dy: *const c_void,
1639    x_desc: cudnnTensorDescriptor_t,
1640    x: *const c_void,
1641    beta: *const c_void,
1642    dx_desc: cudnnTensorDescriptor_t,
1643    dx: *mut c_void,
1644) -> cudnnStatus_t;
1645
1646/// cuDNN: set activation descriptor swish beta. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1647pub type PFN_cudnnSetActivationDescriptorSwishBeta = unsafe extern "C" fn(
1648    desc: cudnnActivationDescriptor_t,
1649    swish_beta: c_double,
1650) -> cudnnStatus_t;
1651/// cuDNN: get activation descriptor swish beta. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1652pub type PFN_cudnnGetActivationDescriptorSwishBeta = unsafe extern "C" fn(
1653    desc: cudnnActivationDescriptor_t,
1654    swish_beta: *mut c_double,
1655) -> cudnnStatus_t;
1656
1657/// cuDNN: LRN cross channel backward. See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
1658pub type PFN_cudnnLRNCrossChannelBackward = unsafe extern "C" fn(
1659    handle: cudnnHandle_t,
1660    norm_desc: cudnnLRNDescriptor_t,
1661    lrn_mode: c_int,
1662    alpha: *const c_void,
1663    y_desc: cudnnTensorDescriptor_t,
1664    y: *const c_void,
1665    dy_desc: cudnnTensorDescriptor_t,
1666    dy: *const c_void,
1667    x_desc: cudnnTensorDescriptor_t,
1668    x: *const c_void,
1669    beta: *const c_void,
1670    dx_desc: cudnnTensorDescriptor_t,
1671    dx: *mut c_void,
1672) -> cudnnStatus_t;
1673
/// cuDNN: divisive normalization forward.
///
/// Function-pointer type for the `cudnnDivisiveNormalizationForward` entry
/// point. Note that a single descriptor (`x_desc`) precedes `x`, `means`,
/// `temp` and `temp2`; only `y` has its own descriptor. `temp`, `temp2` and
/// `y` are the mutable pointers. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnDivisiveNormalizationForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    norm_desc: cudnnLRNDescriptor_t,
    mode: c_int,
    alpha: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    means: *const c_void,
    temp: *mut c_void,
    temp2: *mut c_void,
    beta: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
) -> cudnnStatus_t;
1689
/// cuDNN: divisive normalization backward.
///
/// Function-pointer type for the `cudnnDivisiveNormalizationBackward` entry
/// point. `x`, `means` and `dy` are read-only pointers following `x_desc`;
/// `temp`, `temp2`, `dx` and `d_means` are mutable. Both `dx` and `d_means`
/// follow the single `d_xdmeans_desc` descriptor. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnDivisiveNormalizationBackward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    norm_desc: cudnnLRNDescriptor_t,
    mode: c_int,
    alpha: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    means: *const c_void,
    dy: *const c_void,
    temp: *mut c_void,
    temp2: *mut c_void,
    beta: *const c_void,
    d_xdmeans_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
    d_means: *mut c_void,
) -> cudnnStatus_t;
1707
/// cuDNN: get reduction indices size.
///
/// Function-pointer type for the `cudnnGetReductionIndicesSize` entry point.
/// Takes the handle, a reduce-tensor descriptor and two tensor descriptors;
/// `size_in_bytes` is a `*mut usize`, presumably filled in by the callee
/// (confirm against the cuDNN reference). Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetReductionIndicesSize = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    desc: cudnnReduceTensorDescriptor_t,
    a_desc: cudnnTensorDescriptor_t,
    c_desc: cudnnTensorDescriptor_t,
    size_in_bytes: *mut usize,
) -> cudnnStatus_t;
1716
1717// 4-D tensor / filter readback + strided-Set.
1718
/// cuDNN: set tensor4d descriptor ex.
///
/// Function-pointer type for the `cudnnSetTensor4dDescriptorEx` entry point.
/// All extents (`n`, `c`, `h`, `w`) and the four per-dimension strides are
/// passed by value as `c_int`. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnSetTensor4dDescriptorEx = unsafe extern "C" fn(
    desc: cudnnTensorDescriptor_t,
    data_type: cudnnDataType_t,
    n: c_int,
    c: c_int,
    h: c_int,
    w: c_int,
    n_stride: c_int,
    c_stride: c_int,
    h_stride: c_int,
    w_stride: c_int,
) -> cudnnStatus_t;
1732
/// cuDNN: get tensor4d descriptor.
///
/// Function-pointer type for the `cudnnGetTensor4dDescriptor` entry point.
/// Mirrors the parameter list of [`PFN_cudnnSetTensor4dDescriptorEx`] with
/// every value replaced by a `*mut` pointer; presumably each is written by
/// the callee (confirm against the cuDNN reference). Returns a
/// `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetTensor4dDescriptor = unsafe extern "C" fn(
    desc: cudnnTensorDescriptor_t,
    data_type: *mut cudnnDataType_t,
    n: *mut c_int,
    c: *mut c_int,
    h: *mut c_int,
    w: *mut c_int,
    n_stride: *mut c_int,
    c_stride: *mut c_int,
    h_stride: *mut c_int,
    w_stride: *mut c_int,
) -> cudnnStatus_t;
1746
/// cuDNN: get filter4d descriptor.
///
/// Function-pointer type for the `cudnnGetFilter4dDescriptor` entry point.
/// Besides the filter descriptor, every parameter is a `*mut` pointer
/// (data type, tensor format, and the `k`/`c`/`h`/`w` integers); presumably
/// all are outputs (confirm against the cuDNN reference). Returns a
/// `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetFilter4dDescriptor = unsafe extern "C" fn(
    desc: cudnnFilterDescriptor_t,
    data_type: *mut cudnnDataType_t,
    format: *mut cudnnTensorFormat_t,
    k: *mut c_int,
    c: *mut c_int,
    h: *mut c_int,
    w: *mut c_int,
) -> cudnnStatus_t;
1757
1758// Dropout descriptor save/restore.
1759
/// cuDNN: get dropout descriptor.
///
/// Function-pointer type for the `cudnnGetDropoutDescriptor` entry point.
/// Note the argument order: the dropout descriptor comes before the library
/// handle. `dropout` (`*mut f32`), `states` (`*mut *mut c_void`) and `seed`
/// (`*mut u64`) are pointers that the callee presumably writes through
/// (confirm against the cuDNN reference). Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetDropoutDescriptor = unsafe extern "C" fn(
    desc: cudnnDropoutDescriptor_t,
    handle: cudnnHandle_t,
    dropout: *mut f32,
    states: *mut *mut c_void,
    seed: *mut u64,
) -> cudnnStatus_t;
1768
/// cuDNN: restore dropout descriptor.
///
/// Function-pointer type for the `cudnnRestoreDropoutDescriptor` entry
/// point. Takes the same quantities as [`PFN_cudnnGetDropoutDescriptor`]
/// by value (`dropout`, `states`, `seed`) plus the byte size of the
/// `states` buffer. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnRestoreDropoutDescriptor = unsafe extern "C" fn(
    desc: cudnnDropoutDescriptor_t,
    handle: cudnnHandle_t,
    dropout: f32,
    states: *mut c_void,
    state_size: usize,
    seed: u64,
) -> cudnnStatus_t;
1778
1779// ==========================================================================
1780// Tier 2 — algorithm finders / pickers (v7)
1781// ==========================================================================
1782
/// cuDNN: get convolution forward algorithm.
///
/// Function-pointer type for the `cudnnGetConvolutionForwardAlgorithm_v7`
/// entry point. Takes descriptors only (no tensor data). `perf_results`
/// points at an array of `cudnnConvolutionFwdAlgoPerf_t`; presumably the
/// caller sizes it to `requested_algo_count` and the callee reports the
/// filled count via `returned_algo_count` (confirm against the cuDNN
/// reference). Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetConvolutionForwardAlgorithm_v7 = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    src_desc: cudnnTensorDescriptor_t,
    filter_desc: cudnnFilterDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    dst_desc: cudnnTensorDescriptor_t,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionFwdAlgoPerf_t,
) -> cudnnStatus_t;
1794
/// cuDNN: find convolution forward algorithm.
///
/// Function-pointer type for the `cudnnFindConvolutionForwardAlgorithm`
/// entry point. The parameter list is identical to
/// [`PFN_cudnnGetConvolutionForwardAlgorithm_v7`]: descriptors, a requested
/// count, and `*mut` pointers for the returned count and the
/// `cudnnConvolutionFwdAlgoPerf_t` results. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnFindConvolutionForwardAlgorithm = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    src_desc: cudnnTensorDescriptor_t,
    filter_desc: cudnnFilterDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    dst_desc: cudnnTensorDescriptor_t,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionFwdAlgoPerf_t,
) -> cudnnStatus_t;
1806
/// cuDNN: find convolution forward algorithm ex.
///
/// Function-pointer type for the `cudnnFindConvolutionForwardAlgorithmEx`
/// entry point. Unlike [`PFN_cudnnFindConvolutionForwardAlgorithm`], each
/// descriptor is paired with a data pointer (`src`, `filter` read-only;
/// `dst` mutable) and the caller passes a `workspace` pointer together with
/// its size in bytes. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnFindConvolutionForwardAlgorithmEx = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    src_desc: cudnnTensorDescriptor_t,
    src: *const c_void,
    filter_desc: cudnnFilterDescriptor_t,
    filter: *const c_void,
    conv_desc: cudnnConvolutionDescriptor_t,
    dst_desc: cudnnTensorDescriptor_t,
    dst: *mut c_void,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionFwdAlgoPerf_t,
    workspace: *mut c_void,
    workspace_size: usize,
) -> cudnnStatus_t;
1823
/// cuDNN: get convolution backward data algorithm.
///
/// Function-pointer type for the
/// `cudnnGetConvolutionBackwardDataAlgorithm_v7` entry point. Here the
/// filter descriptor comes first, followed by the `diff_desc` and
/// `grad_desc` tensor descriptors; results are reported through
/// `returned_algo_count` and a `cudnnConvolutionBwdDataAlgoPerf_t` array
/// pointer. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetConvolutionBackwardDataAlgorithm_v7 = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    filter_desc: cudnnFilterDescriptor_t,
    diff_desc: cudnnTensorDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    grad_desc: cudnnTensorDescriptor_t,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionBwdDataAlgoPerf_t,
) -> cudnnStatus_t;
1835
/// cuDNN: find convolution backward data algorithm.
///
/// Function-pointer type for the
/// `cudnnFindConvolutionBackwardDataAlgorithm` entry point. The parameter
/// list is identical to [`PFN_cudnnGetConvolutionBackwardDataAlgorithm_v7`].
/// Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnFindConvolutionBackwardDataAlgorithm = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    filter_desc: cudnnFilterDescriptor_t,
    diff_desc: cudnnTensorDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    grad_desc: cudnnTensorDescriptor_t,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionBwdDataAlgoPerf_t,
) -> cudnnStatus_t;
1847
/// cuDNN: get convolution backward filter algorithm.
///
/// Function-pointer type for the
/// `cudnnGetConvolutionBackwardFilterAlgorithm_v7` entry point. `src_desc`
/// and `diff_desc` are tensor descriptors while `grad_desc` is a *filter*
/// descriptor; results are reported through `returned_algo_count` and a
/// `cudnnConvolutionBwdFilterAlgoPerf_t` array pointer. Returns a
/// `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetConvolutionBackwardFilterAlgorithm_v7 = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    src_desc: cudnnTensorDescriptor_t,
    diff_desc: cudnnTensorDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    grad_desc: cudnnFilterDescriptor_t,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionBwdFilterAlgoPerf_t,
) -> cudnnStatus_t;
1859
/// cuDNN: find convolution backward filter algorithm.
///
/// Function-pointer type for the
/// `cudnnFindConvolutionBackwardFilterAlgorithm` entry point. The parameter
/// list is identical to
/// [`PFN_cudnnGetConvolutionBackwardFilterAlgorithm_v7`]. Returns a
/// `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnFindConvolutionBackwardFilterAlgorithm = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    src_desc: cudnnTensorDescriptor_t,
    diff_desc: cudnnTensorDescriptor_t,
    conv_desc: cudnnConvolutionDescriptor_t,
    grad_desc: cudnnFilterDescriptor_t,
    requested_algo_count: c_int,
    returned_algo_count: *mut c_int,
    perf_results: *mut cudnnConvolutionBwdFilterAlgoPerf_t,
) -> cudnnStatus_t;
1871
1872// ==========================================================================
1873// Tier 3 — BatchNorm "Ex" + generic Normalization API
1874// ==========================================================================
1875
/// cuDNN: get batch normalization forward training ex workspace size.
///
/// Function-pointer type for the
/// `cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize` entry point.
/// Takes the mode/ops selectors and descriptors only; `size_in_bytes` is a
/// `*mut usize`, presumably written by the callee (confirm against the
/// cuDNN reference). Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize =
    unsafe extern "C" fn(
        handle: cudnnHandle_t,
        mode: cudnnBatchNormMode_t,
        bn_ops: cudnnBatchNormOps_t,
        x_desc: cudnnTensorDescriptor_t,
        z_desc: cudnnTensorDescriptor_t,
        y_desc: cudnnTensorDescriptor_t,
        bn_scale_bias_mean_var_desc: cudnnTensorDescriptor_t,
        activation_desc: cudnnActivationDescriptor_t,
        size_in_bytes: *mut usize,
    ) -> cudnnStatus_t;
1889
/// cuDNN: get batch normalization backward ex workspace size.
///
/// Function-pointer type for the
/// `cudnnGetBatchNormalizationBackwardExWorkspaceSize` entry point. Takes
/// the mode/ops selectors and the `x`/`y`/`dy`/`dz`/`dx`/scale-bias and
/// activation descriptors; `size_in_bytes` is a `*mut usize`, presumably
/// written by the callee (confirm against the cuDNN reference). Returns a
/// `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetBatchNormalizationBackwardExWorkspaceSize =
    unsafe extern "C" fn(
        handle: cudnnHandle_t,
        mode: cudnnBatchNormMode_t,
        bn_ops: cudnnBatchNormOps_t,
        x_desc: cudnnTensorDescriptor_t,
        y_desc: cudnnTensorDescriptor_t,
        dy_desc: cudnnTensorDescriptor_t,
        dz_desc: cudnnTensorDescriptor_t,
        dx_desc: cudnnTensorDescriptor_t,
        d_bn_scale_bias_desc: cudnnTensorDescriptor_t,
        activation_desc: cudnnActivationDescriptor_t,
        size_in_bytes: *mut usize,
    ) -> cudnnStatus_t;
1905
/// cuDNN: get batch normalization training ex reserve space size.
///
/// Function-pointer type for the
/// `cudnnGetBatchNormalizationTrainingExReserveSpaceSize` entry point. Note
/// that the activation descriptor precedes `x_desc` here. `size_in_bytes`
/// is a `*mut usize`, presumably written by the callee (confirm against the
/// cuDNN reference). Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetBatchNormalizationTrainingExReserveSpaceSize =
    unsafe extern "C" fn(
        handle: cudnnHandle_t,
        mode: cudnnBatchNormMode_t,
        bn_ops: cudnnBatchNormOps_t,
        activation_desc: cudnnActivationDescriptor_t,
        x_desc: cudnnTensorDescriptor_t,
        size_in_bytes: *mut usize,
    ) -> cudnnStatus_t;
1916
/// cuDNN: batch normalization forward training ex.
///
/// Function-pointer type for the `cudnnBatchNormalizationForwardTrainingEx`
/// entry point. Read-only data pointers: `x`, `z`, `bn_scale`, `bn_bias`.
/// Mutable data pointers: `y`, the running mean/variance, the saved
/// mean/inverse-variance, `workspace` and `reserve_space` (each of the last
/// two is followed by its byte size). `exponential_average_factor` and
/// `epsilon` are passed by value as `c_double`. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnBatchNormalizationForwardTrainingEx = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnBatchNormMode_t,
    bn_ops: cudnnBatchNormOps_t,
    alpha: *const c_void,
    beta: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    z_desc: cudnnTensorDescriptor_t,
    z: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
    bn_scale_bias_mean_var_desc: cudnnTensorDescriptor_t,
    bn_scale: *const c_void,
    bn_bias: *const c_void,
    exponential_average_factor: c_double,
    result_running_mean: *mut c_void,
    result_running_variance: *mut c_void,
    epsilon: c_double,
    save_mean: *mut c_void,
    save_inv_variance: *mut c_void,
    activation_desc: cudnnActivationDescriptor_t,
    workspace: *mut c_void,
    workspace_size: usize,
    reserve_space: *mut c_void,
    reserve_space_size: usize,
) -> cudnnStatus_t;
1945
/// cuDNN: batch normalization backward ex.
///
/// Function-pointer type for the `cudnnBatchNormalizationBackwardEx` entry
/// point. Four untyped scaling pointers come first (alpha/beta for the data
/// diff, then alpha/beta for the parameter diff). Read-only data pointers:
/// `x`, `y`, `dy`, `bn_scale`, `bn_bias`, `saved_mean`,
/// `saved_inv_variance`. Mutable data pointers: `dz`, `dx`,
/// `d_bn_scale_result`, `d_bn_bias_result`, `workspace`, `reserve_space`.
/// Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnBatchNormalizationBackwardEx = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnBatchNormMode_t,
    bn_ops: cudnnBatchNormOps_t,
    alpha_data_diff: *const c_void,
    beta_data_diff: *const c_void,
    alpha_param_diff: *const c_void,
    beta_param_diff: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *const c_void,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    dz_desc: cudnnTensorDescriptor_t,
    dz: *mut c_void,
    dx_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
    d_bn_scale_bias_desc: cudnnTensorDescriptor_t,
    bn_scale: *const c_void,
    bn_bias: *const c_void,
    d_bn_scale_result: *mut c_void,
    d_bn_bias_result: *mut c_void,
    epsilon: c_double,
    saved_mean: *const c_void,
    saved_inv_variance: *const c_void,
    activation_desc: cudnnActivationDescriptor_t,
    workspace: *mut c_void,
    workspace_size: usize,
    reserve_space: *mut c_void,
    reserve_space_size: usize,
) -> cudnnStatus_t;
1979
1980// Generic Normalization API (cuDNN 8+).
1981
/// cuDNN: normalization forward inference.
///
/// Function-pointer type for the `cudnnNormalizationForwardInference` entry
/// point. Selected by `mode`, `norm_ops` and `algo`. Every data pointer is
/// read-only except `y`. Scale/bias share `norm_scale_bias_desc`, and the
/// estimated mean/variance share `norm_mean_var_desc`. `epsilon` and
/// `group_count` are the trailing by-value arguments. Returns a
/// `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnNormalizationForwardInference = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnNormMode_t,
    norm_ops: cudnnNormOps_t,
    algo: cudnnNormAlgo_t,
    alpha: *const c_void,
    beta: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    norm_scale_bias_desc: cudnnTensorDescriptor_t,
    norm_scale: *const c_void,
    norm_bias: *const c_void,
    norm_mean_var_desc: cudnnTensorDescriptor_t,
    estimated_mean: *const c_void,
    estimated_variance: *const c_void,
    z_desc: cudnnTensorDescriptor_t,
    z: *const c_void,
    activation_desc: cudnnActivationDescriptor_t,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
    epsilon: c_double,
    group_count: c_int,
) -> cudnnStatus_t;
2006
/// cuDNN: get normalization forward training workspace size.
///
/// Function-pointer type for the
/// `cudnnGetNormalizationForwardTrainingWorkspaceSize` entry point. Takes
/// selectors and descriptors only. Note that `size_in_bytes` (`*mut usize`,
/// presumably an output; confirm against the cuDNN reference) is *not* the
/// last parameter: `group_count` follows it. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetNormalizationForwardTrainingWorkspaceSize =
    unsafe extern "C" fn(
        handle: cudnnHandle_t,
        mode: cudnnNormMode_t,
        norm_ops: cudnnNormOps_t,
        algo: cudnnNormAlgo_t,
        x_desc: cudnnTensorDescriptor_t,
        z_desc: cudnnTensorDescriptor_t,
        y_desc: cudnnTensorDescriptor_t,
        norm_scale_bias_desc: cudnnTensorDescriptor_t,
        activation_desc: cudnnActivationDescriptor_t,
        norm_mean_var_desc: cudnnTensorDescriptor_t,
        size_in_bytes: *mut usize,
        group_count: c_int,
    ) -> cudnnStatus_t;
2023
/// cuDNN: get normalization backward workspace size.
///
/// Function-pointer type for the
/// `cudnnGetNormalizationBackwardWorkspaceSize` entry point. Takes
/// selectors and descriptors only. `size_in_bytes` (`*mut usize`,
/// presumably an output; confirm against the cuDNN reference) is followed
/// by the by-value `group_count`. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetNormalizationBackwardWorkspaceSize = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnNormMode_t,
    norm_ops: cudnnNormOps_t,
    algo: cudnnNormAlgo_t,
    x_desc: cudnnTensorDescriptor_t,
    y_desc: cudnnTensorDescriptor_t,
    dy_desc: cudnnTensorDescriptor_t,
    dz_desc: cudnnTensorDescriptor_t,
    dx_desc: cudnnTensorDescriptor_t,
    d_norm_scale_bias_desc: cudnnTensorDescriptor_t,
    activation_desc: cudnnActivationDescriptor_t,
    norm_mean_var_desc: cudnnTensorDescriptor_t,
    size_in_bytes: *mut usize,
    group_count: c_int,
) -> cudnnStatus_t;
2041
/// cuDNN: get normalization training reserve space size.
///
/// Function-pointer type for the
/// `cudnnGetNormalizationTrainingReserveSpaceSize` entry point. The
/// activation descriptor precedes `x_desc`. `size_in_bytes` (`*mut usize`,
/// presumably an output; confirm against the cuDNN reference) is followed
/// by the by-value `group_count`. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetNormalizationTrainingReserveSpaceSize =
    unsafe extern "C" fn(
        handle: cudnnHandle_t,
        mode: cudnnNormMode_t,
        norm_ops: cudnnNormOps_t,
        algo: cudnnNormAlgo_t,
        activation_desc: cudnnActivationDescriptor_t,
        x_desc: cudnnTensorDescriptor_t,
        size_in_bytes: *mut usize,
        group_count: c_int,
    ) -> cudnnStatus_t;
2054
/// cuDNN: normalization forward training.
///
/// Function-pointer type for the `cudnnNormalizationForwardTraining` entry
/// point. Read-only data pointers: `x`, `norm_scale`, `norm_bias`, `z`.
/// Mutable data pointers: the running mean/variance, the saved
/// mean/inverse-variance, `y`, `workspace` and `reserve_space` (each of the
/// last two is followed by its byte size). `exponential_average_factor` sits
/// between the scale/bias pointers and `norm_mean_var_desc`; `group_count`
/// is the final argument. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnNormalizationForwardTraining = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnNormMode_t,
    norm_ops: cudnnNormOps_t,
    algo: cudnnNormAlgo_t,
    alpha: *const c_void,
    beta: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    norm_scale_bias_desc: cudnnTensorDescriptor_t,
    norm_scale: *const c_void,
    norm_bias: *const c_void,
    exponential_average_factor: c_double,
    norm_mean_var_desc: cudnnTensorDescriptor_t,
    result_running_mean: *mut c_void,
    result_running_variance: *mut c_void,
    epsilon: c_double,
    save_mean: *mut c_void,
    save_inv_variance: *mut c_void,
    activation_desc: cudnnActivationDescriptor_t,
    z_desc: cudnnTensorDescriptor_t,
    z: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *mut c_void,
    workspace: *mut c_void,
    workspace_size: usize,
    reserve_space: *mut c_void,
    reserve_space_size: usize,
    group_count: c_int,
) -> cudnnStatus_t;
2086
/// cuDNN: normalization backward.
///
/// Function-pointer type for the `cudnnNormalizationBackward` entry point.
/// Four untyped scaling pointers come first (alpha/beta for the data diff,
/// then alpha/beta for the parameter diff). Read-only data pointers: `x`,
/// `y`, `dy`, `norm_scale`, `norm_bias`, `saved_mean`,
/// `saved_inv_variance`. Mutable data pointers: `dz`, `dx`, `d_norm_scale`,
/// `d_norm_bias`, `workspace`, `reserve_space`. `group_count` is the final
/// argument. Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnNormalizationBackward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    mode: cudnnNormMode_t,
    norm_ops: cudnnNormOps_t,
    algo: cudnnNormAlgo_t,
    alpha_data_diff: *const c_void,
    beta_data_diff: *const c_void,
    alpha_param_diff: *const c_void,
    beta_param_diff: *const c_void,
    x_desc: cudnnTensorDescriptor_t,
    x: *const c_void,
    y_desc: cudnnTensorDescriptor_t,
    y: *const c_void,
    dy_desc: cudnnTensorDescriptor_t,
    dy: *const c_void,
    dz_desc: cudnnTensorDescriptor_t,
    dz: *mut c_void,
    dx_desc: cudnnTensorDescriptor_t,
    dx: *mut c_void,
    d_norm_scale_bias_desc: cudnnTensorDescriptor_t,
    norm_scale: *const c_void,
    norm_bias: *const c_void,
    d_norm_scale: *mut c_void,
    d_norm_bias: *mut c_void,
    epsilon: c_double,
    norm_mean_var_desc: cudnnTensorDescriptor_t,
    saved_mean: *const c_void,
    saved_inv_variance: *const c_void,
    activation_desc: cudnnActivationDescriptor_t,
    workspace: *mut c_void,
    workspace_size: usize,
    reserve_space: *mut c_void,
    reserve_space_size: usize,
    group_count: c_int,
) -> cudnnStatus_t;
2123
2124// ==========================================================================
2125// Tier 4 — RNN v8 modernization
2126// ==========================================================================
2127
/// cuDNN: set RNN descriptor.
///
/// Function-pointer type for the `cudnnSetRNNDescriptor_v8` entry point.
/// Everything is passed by value: the enum-typed selectors (`bias_mode` is
/// declared as a plain `c_int` here), two data types, the math type, four
/// `i32` sizes, the dropout descriptor handle, and a `u32` flag word.
/// Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnSetRNNDescriptor_v8 = unsafe extern "C" fn(
    rnn_desc: cudnnRNNDescriptor_t,
    algo: cudnnRNNAlgo_t,
    cell_mode: cudnnRNNMode_t,
    bias_mode: c_int,
    dir_mode: cudnnDirectionMode_t,
    input_mode: cudnnRNNInputMode_t,
    data_type: cudnnDataType_t,
    math_prec: cudnnDataType_t,
    math_type: cudnnMathType_t,
    input_size: i32,
    hidden_size: i32,
    proj_size: i32,
    num_layers: i32,
    dropout_desc: cudnnDropoutDescriptor_t,
    aux_flags: u32,
) -> cudnnStatus_t;
2146
/// cuDNN: build RNN dynamic.
///
/// Function-pointer type for the `cudnnBuildRNNDynamic` entry point: takes
/// the library handle, the RNN descriptor and a `c_int` mini-batch value,
/// and returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnBuildRNNDynamic = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    rnn_desc: cudnnRNNDescriptor_t,
    mini_batch: c_int,
) -> cudnnStatus_t;
2153
/// cuDNN: get RNN temp space sizes.
///
/// Function-pointer type for the `cudnnGetRNNTempSpaceSizes` entry point.
/// `fwd_mode` is declared as a plain `c_int`; `x_desc` is an RNN *data*
/// descriptor. The two `*mut usize` parameters are presumably outputs for
/// the work-space and reserve-space sizes (confirm against the cuDNN
/// reference). Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetRNNTempSpaceSizes = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    rnn_desc: cudnnRNNDescriptor_t,
    fwd_mode: c_int,
    x_desc: cudnnRNNDataDescriptor_t,
    work_space_size: *mut usize,
    reserve_space_size: *mut usize,
) -> cudnnStatus_t;
2163
/// cuDNN: get RNN weight space size.
///
/// Function-pointer type for the `cudnnGetRNNWeightSpaceSize` entry point.
/// `weight_space_size` is a `*mut usize`, presumably written by the callee
/// (confirm against the cuDNN reference). Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetRNNWeightSpaceSize = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    rnn_desc: cudnnRNNDescriptor_t,
    weight_space_size: *mut usize,
) -> cudnnStatus_t;
2170
/// cuDNN: get RNN weight params.
///
/// Function-pointer type for the `cudnnGetRNNWeightParams` entry point.
/// The weight buffer is passed as a size plus a read-only pointer, and is
/// indexed by the `i32` values `pseudo_layer` and `lin_layer_id`. `m_addr`
/// and `b_addr` are `*mut *mut c_void`, each preceded by a tensor
/// descriptor; presumably the callee stores addresses through them (confirm
/// against the cuDNN reference). Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetRNNWeightParams = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    rnn_desc: cudnnRNNDescriptor_t,
    pseudo_layer: i32,
    weight_space_size: usize,
    weight_space: *const c_void,
    lin_layer_id: i32,
    m_desc: cudnnTensorDescriptor_t,
    m_addr: *mut *mut c_void,
    b_desc: cudnnTensorDescriptor_t,
    b_addr: *mut *mut c_void,
) -> cudnnStatus_t;
2184
2185// ==========================================================================
2186// Tier 5 — Multi-head attention
2187// ==========================================================================
2188
/// cuDNN: create attn descriptor.
///
/// Function-pointer type for the `cudnnCreateAttnDescriptor` entry point.
/// `desc` is a pointer to a descriptor handle, presumably filled in on
/// success (confirm against the cuDNN reference). Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnCreateAttnDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnAttnDescriptor_t) -> cudnnStatus_t;
/// cuDNN: destroy attn descriptor.
///
/// Function-pointer type for the `cudnnDestroyAttnDescriptor` entry point.
/// Takes the descriptor handle by value and returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnDestroyAttnDescriptor =
    unsafe extern "C" fn(desc: cudnnAttnDescriptor_t) -> cudnnStatus_t;
2195
/// cuDNN: set attn descriptor.
///
/// Function-pointer type for the `cudnnSetAttnDescriptor` entry point.
/// Everything is passed by value: a `u32` mode word, the head count, a
/// `c_double` scaler, two data types and the math type, two dropout
/// descriptor handles, and eleven `i32` size/limit values. Returns a
/// `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnSetAttnDescriptor = unsafe extern "C" fn(
    desc: cudnnAttnDescriptor_t,
    attn_mode: u32,
    n_heads: i32,
    sm_scaler: c_double,
    data_type: cudnnDataType_t,
    compute_prec: cudnnDataType_t,
    math_type: cudnnMathType_t,
    attn_dropout_desc: cudnnDropoutDescriptor_t,
    post_dropout_desc: cudnnDropoutDescriptor_t,
    q_size: i32,
    k_size: i32,
    v_size: i32,
    q_proj_size: i32,
    k_proj_size: i32,
    v_proj_size: i32,
    o_proj_size: i32,
    qo_max_seq_length: i32,
    kv_max_seq_length: i32,
    max_batch_size: i32,
    max_beam_size: i32,
) -> cudnnStatus_t;
2219
/// cuDNN: get multi head attn buffers.
///
/// Function-pointer type for the `cudnnGetMultiHeadAttnBuffers` entry
/// point. The three `*mut usize` parameters are presumably outputs for the
/// weight, work-space and reserve-space sizes in bytes (confirm against the
/// cuDNN reference). Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetMultiHeadAttnBuffers = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    attn_desc: cudnnAttnDescriptor_t,
    weight_size_in_bytes: *mut usize,
    work_space_size_in_bytes: *mut usize,
    reserve_space_size_in_bytes: *mut usize,
) -> cudnnStatus_t;
2228
/// cuDNN: get multi head attn weights.
///
/// Function-pointer type for the `cudnnGetMultiHeadAttnWeights` entry
/// point. `w_kind` is declared as a plain `c_int`. The weight buffer is
/// passed as a byte size plus a read-only pointer; `w_addr` is a
/// `*mut *mut c_void` preceded by a tensor descriptor, presumably receiving
/// an address inside that buffer (confirm against the cuDNN reference).
/// Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnGetMultiHeadAttnWeights = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    attn_desc: cudnnAttnDescriptor_t,
    w_kind: c_int,
    weight_size_in_bytes: usize,
    weights: *const c_void,
    w_desc: cudnnTensorDescriptor_t,
    w_addr: *mut *mut c_void,
) -> cudnnStatus_t;
2239
/// cuDNN: multi head attn forward.
///
/// Function-pointer type for the `cudnnMultiHeadAttnForward` entry point.
/// Four `*const i32` array pointers follow `curr_idx` (window indices and
/// the two sequence-length arrays). `queries` and `residuals` both follow
/// `q_desc`; `out` is the only mutable tensor pointer. For the weight,
/// work-space and reserve-space buffers the byte size precedes the pointer.
/// Returns a `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnMultiHeadAttnForward = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    attn_desc: cudnnAttnDescriptor_t,
    curr_idx: i32,
    lo_win_idx: *const i32,
    hi_win_idx: *const i32,
    dev_seq_lengths_qo: *const i32,
    dev_seq_lengths_kv: *const i32,
    q_desc: cudnnSeqDataDescriptor_t,
    queries: *const c_void,
    residuals: *const c_void,
    k_desc: cudnnSeqDataDescriptor_t,
    keys: *const c_void,
    v_desc: cudnnSeqDataDescriptor_t,
    values: *const c_void,
    o_desc: cudnnSeqDataDescriptor_t,
    out: *mut c_void,
    weight_size_in_bytes: usize,
    weights: *const c_void,
    work_space_size_in_bytes: usize,
    work_space: *mut c_void,
    reserve_space_size_in_bytes: usize,
    reserve_space: *mut c_void,
) -> cudnnStatus_t;
2265
/// cuDNN: multi head attn backward data.
///
/// Function-pointer type for the `cudnnMultiHeadAttnBackwardData` entry
/// point. There is no `curr_idx` here, unlike
/// [`PFN_cudnnMultiHeadAttnForward`]. Each of `dq_desc`, `dk_desc` and
/// `dv_desc` is followed by a mutable gradient pointer and then the
/// matching read-only forward tensor. For the weight, work-space and
/// reserve-space buffers the byte size precedes the pointer. Returns a
/// `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnMultiHeadAttnBackwardData = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    attn_desc: cudnnAttnDescriptor_t,
    lo_win_idx: *const i32,
    hi_win_idx: *const i32,
    dev_seq_lengths_dqdo: *const i32,
    dev_seq_lengths_dkdv: *const i32,
    do_desc: cudnnSeqDataDescriptor_t,
    dout: *const c_void,
    dq_desc: cudnnSeqDataDescriptor_t,
    dqueries: *mut c_void,
    queries: *const c_void,
    dk_desc: cudnnSeqDataDescriptor_t,
    dkeys: *mut c_void,
    keys: *const c_void,
    dv_desc: cudnnSeqDataDescriptor_t,
    dvalues: *mut c_void,
    values: *const c_void,
    weight_size_in_bytes: usize,
    weights: *const c_void,
    work_space_size_in_bytes: usize,
    work_space: *mut c_void,
    reserve_space_size_in_bytes: usize,
    reserve_space: *mut c_void,
) -> cudnnStatus_t;
2292
/// cuDNN: multi head attn backward weights.
///
/// Function-pointer type for the `cudnnMultiHeadAttnBackwardWeights` entry
/// point. `add_grad` is declared as a plain `c_int`. All tensor inputs
/// (`queries`, `keys`, `values`, `dout`, `weights`) are read-only;
/// `dweights`, `work_space` and `reserve_space` are mutable. A single
/// `weight_size_in_bytes` precedes both `weights` and `dweights`. Returns a
/// `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnMultiHeadAttnBackwardWeights = unsafe extern "C" fn(
    handle: cudnnHandle_t,
    attn_desc: cudnnAttnDescriptor_t,
    add_grad: c_int,
    q_desc: cudnnSeqDataDescriptor_t,
    queries: *const c_void,
    k_desc: cudnnSeqDataDescriptor_t,
    keys: *const c_void,
    v_desc: cudnnSeqDataDescriptor_t,
    values: *const c_void,
    do_desc: cudnnSeqDataDescriptor_t,
    dout: *const c_void,
    weight_size_in_bytes: usize,
    weights: *const c_void,
    dweights: *mut c_void,
    work_space_size_in_bytes: usize,
    work_space: *mut c_void,
    reserve_space_size_in_bytes: usize,
    reserve_space: *mut c_void,
) -> cudnnStatus_t;
2314
2315// SeqDataDescriptor lifetime helpers.
2316
/// cuDNN: create seq data descriptor.
///
/// Function-pointer type for the `cudnnCreateSeqDataDescriptor` entry
/// point. `desc` is a pointer to a descriptor handle, presumably filled in
/// on success (confirm against the cuDNN reference). Returns a
/// `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnCreateSeqDataDescriptor =
    unsafe extern "C" fn(desc: *mut cudnnSeqDataDescriptor_t) -> cudnnStatus_t;
/// cuDNN: destroy seq data descriptor.
///
/// Function-pointer type for the `cudnnDestroySeqDataDescriptor` entry
/// point. Takes the descriptor handle by value and returns a
/// `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnDestroySeqDataDescriptor =
    unsafe extern "C" fn(desc: cudnnSeqDataDescriptor_t) -> cudnnStatus_t;
/// cuDNN: set seq data descriptor.
///
/// Function-pointer type for the `cudnnSetSeqDataDescriptor` entry point.
/// `dim_a`, `axes` and `seq_length_array` are read-only `c_int` array
/// pointers; `nb_dims` and `seq_length_array_size` are passed alongside
/// them by value. NOTE(review): the required lengths of `dim_a` and `axes`
/// are not visible from this signature — confirm against the cuDNN
/// reference. `padding_fill` is an untyped read-only pointer. Returns a
/// `cudnnStatus_t`.
/// See <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>.
pub type PFN_cudnnSetSeqDataDescriptor = unsafe extern "C" fn(
    desc: cudnnSeqDataDescriptor_t,
    data_type: cudnnDataType_t,
    nb_dims: c_int,
    dim_a: *const c_int,
    axes: *const c_int,
    seq_length_array_size: usize,
    seq_length_array: *const c_int,
    padding_fill: *const c_void,
) -> cudnnStatus_t;
2334
2335// ---- loader --------------------------------------------------------------
2336
/// File names to try when loading the cuDNN shared library, newest major
/// version first.
///
/// Only bare file names are produced here; cuDNN's non-standard Windows
/// install location (`C:\Program Files\NVIDIA\CUDNN\v<ver>\bin\<cuda_major>\`,
/// which is not on the default DLL search path) is not handled in this
/// function.
///
/// * Linux: `libcudnn.so.9`, `libcudnn.so.8`, then the unversioned
///   `libcudnn.so`.
/// * Windows: `cudnn64_9.dll`, then `cudnn64_8.dll`.
/// * Any other target: no candidates.
fn cudnn_candidates() -> Vec<String> {
    let names: &[&str] = if cfg!(target_os = "linux") {
        &["libcudnn.so.9", "libcudnn.so.8", "libcudnn.so"]
    } else if cfg!(target_os = "windows") {
        &["cudnn64_9.dll", "cudnn64_8.dll"]
    } else {
        &[]
    };
    names.iter().map(|name| (*name).to_owned()).collect()
}
2358
2359/// Detect the CUDA toolkit major version that cuDNN should be paired
2360/// against. Returns `None` if no signal is available.
2361///
2362/// Strategy (Windows-style env vars work on Linux too if set):
2363///   1. `CUDA_PATH` typically ends in `vNN.M` — e.g.
2364///      `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6` →
2365///      `Some(12)`.
2366///   2. Fall back to scanning `CUDA_PATH_V<NN>_<M>` env vars and
2367///      picking the highest `<NN>` present.
2368fn detect_cuda_major() -> Option<u32> {
2369    if let Ok(p) = std::env::var("CUDA_PATH") {
2370        if let Some(n) = parse_cuda_major_from_path(&p) {
2371            return Some(n);
2372        }
2373    }
2374    // CUDA_PATH_V12_6=... CUDA_PATH_V11_8=... — pick the highest.
2375    let mut best: Option<u32> = None;
2376    for (k, _) in std::env::vars() {
2377        if let Some(rest) = k.strip_prefix("CUDA_PATH_V") {
2378            // rest looks like "12_6"
2379            if let Some((maj, _)) = rest.split_once('_') {
2380                if let Ok(n) = maj.parse::<u32>() {
2381                    best = Some(best.map_or(n, |b| b.max(n)));
2382                }
2383            }
2384        }
2385    }
2386    best
2387}
2388
/// Look for a path component matching `vNN.M` (case-insensitive `v`) and
/// return `NN`.
///
/// The path is split on both `/` and `\`, and components are examined left
/// to right. A component qualifies when it starts with `v`/`V` and the text
/// between that letter and the first `.` parses as a `u32`. Components that
/// do not qualify — including ones that merely start with `v` such as
/// `var`, `vendor` or `Volumes` — are skipped rather than aborting the
/// search, so a versioned component later in the path is still found.
///
/// Returns `None` when no component qualifies.
fn parse_cuda_major_from_path(p: &str) -> Option<u32> {
    p.split(['/', '\\']).find_map(|component| {
        let rest = component.strip_prefix(['v', 'V'])?;
        let (major, _minor) = rest.split_once('.')?;
        major.parse::<u32>().ok()
    })
}
2406
2407fn cudnn_extra_search_dirs() -> Vec<PathBuf> {
2408    let mut out = Vec::new();
2409
2410    if let Ok(p) = std::env::var("CUDNN_PATH") {
2411        let base = PathBuf::from(p);
2412        if cfg!(target_os = "windows") {
2413            out.push(base.join("bin"));
2414        } else {
2415            out.push(base.join("lib"));
2416            out.push(base.join("lib64"));
2417        }
2418    }
2419
2420    if cfg!(target_os = "windows") {
2421        // Default Windows install: `C:\Program Files\NVIDIA\CUDNN\v*\bin\<cuda_major>`.
2422        // The numeric subdirectory under `bin/` is the CUDA major version this
2423        // cuDNN build targets. We must NOT push every subdirectory blindly —
2424        // on a host with cuDNN installed for multiple CUDA majors that puts
2425        // the wrong DLL flavor on the search path. Prefer the subdir matching
2426        // the detected CUDA major; fall back to highest-major-first when we
2427        // can't tell.
2428        let target_major = detect_cuda_major();
2429        if let Ok(pf) = std::env::var("ProgramFiles") {
2430            let cudnn_root = PathBuf::from(pf).join("NVIDIA").join("CUDNN");
2431            if let Ok(read_dir) = std::fs::read_dir(&cudnn_root) {
2432                for entry in read_dir.flatten() {
2433                    let p = entry.path();
2434                    if !p.is_dir() {
2435                        continue;
2436                    }
2437                    let bin = p.join("bin");
2438                    let mut numbered: Vec<(u32, PathBuf)> = Vec::new();
2439                    let mut unnumbered: Vec<PathBuf> = Vec::new();
2440                    if let Ok(sub) = std::fs::read_dir(&bin) {
2441                        for s in sub.flatten() {
2442                            let sp = s.path();
2443                            if !sp.is_dir() {
2444                                continue;
2445                            }
2446                            match sp
2447                                .file_name()
2448                                .and_then(|n| n.to_str())
2449                                .and_then(|s| s.parse::<u32>().ok())
2450                            {
2451                                Some(n) => numbered.push((n, sp)),
2452                                None => unnumbered.push(sp),
2453                            }
2454                        }
2455                    }
2456                    // Match the running CUDA major when known.
2457                    if let Some(target) = target_major {
2458                        if let Some(pos) = numbered.iter().position(|(n, _)| *n == target) {
2459                            let (_, matched) = numbered.swap_remove(pos);
2460                            out.push(matched);
2461                        } else {
2462                            // No exact match — try highest <= target, then fall through.
2463                            numbered.sort_by_key(|b| std::cmp::Reverse(b.0));
2464                            if let Some(pos) = numbered.iter().position(|(n, _)| *n <= target) {
2465                                let (_, matched) = numbered.remove(pos);
2466                                out.push(matched);
2467                            } else if let Some((_, p)) = numbered.into_iter().next() {
2468                                out.push(p);
2469                            }
2470                        }
2471                    } else {
2472                        // No detection signal — push highest-major first so
2473                        // newest cuDNN is tried first, but include all as
2474                        // fallbacks so existing single-CUDA setups still work.
2475                        numbered.sort_by_key(|b| std::cmp::Reverse(b.0));
2476                        for (_, p) in numbered {
2477                            out.push(p);
2478                        }
2479                    }
2480                    // Non-numeric subdirs (rare; older cuDNN packagings)
2481                    // pass through unfiltered.
2482                    out.extend(unnumbered);
2483                }
2484            }
2485        }
2486    }
2487
2488    out
2489}
2490
/// cuDNN 9's main DLL (`cudnn64_9.dll`) is a facade that depends on several
/// companion DLLs in the same directory (`cudnn_ops64_9.dll`,
/// `cudnn_graph64_9.dll`, …). Windows resolves those dependencies via the
/// DLL search path, so we must ensure the cuDNN bin directory is on PATH
/// before `libloading` calls `LoadLibraryExW`.
///
/// Prepends every entry of `extra_dirs` that is not already present on
/// `PATH`, keeping their relative order. The pre-existing `PATH` value is
/// appended verbatim, so nothing the user configured is lost or rewritten.
///
/// The work runs at most once per process: only the `extra_dirs` passed to
/// the first call are considered, and later calls are no-ops.
#[cfg(target_os = "windows")]
fn ensure_cudnn_on_path(extra_dirs: &[PathBuf]) {
    use std::sync::OnceLock;
    static DONE: OnceLock<()> = OnceLock::new();
    DONE.get_or_init(|| {
        // `var_os`, not `var`: `var` fails when PATH contains non-Unicode
        // data, and treating that failure as "PATH is empty" would replace
        // the user's whole PATH with just the cuDNN directories.
        let existing = std::env::var_os("PATH").unwrap_or_default();
        let on_path: Vec<PathBuf> = std::env::split_paths(&existing).collect();

        let mut missing: Vec<&PathBuf> = Vec::new();
        for dir in extra_dirs {
            // An empty PATH entry would be a bogus search location, so it is
            // never added. Duplicates (already on PATH, or repeated within
            // `extra_dirs`) are skipped as well.
            let skip = dir.as_os_str().is_empty()
                || on_path.iter().any(|p| p == dir)
                || missing.iter().any(|p| *p == dir);
            if !skip {
                missing.push(dir);
            }
        }
        if missing.is_empty() {
            return;
        }

        // `join_paths` applies the platform's quoting rules (e.g. for a
        // directory containing `;`). If a directory cannot be represented
        // on PATH at all, leave the environment untouched — this whole
        // step is best-effort.
        let Ok(mut new_path) = std::env::join_paths(missing) else {
            return;
        };
        if !existing.is_empty() {
            new_path.push(";");
            new_path.push(&existing);
        }

        // SAFETY: `set_var` is unsafe in Rust 2024 because mutating the
        // process environment races with concurrent reads from any other
        // thread that calls `getenv`. We're inside a `OnceLock::get_or_init`
        // so this runs at most once per process, before any cuDNN DLL is
        // loaded. The only readers we care about (Windows's DLL search
        // path inside `LoadLibraryExW`) come *after* this initialization.
        unsafe {
            std::env::set_var("PATH", new_path);
        }
    });
}

/// Non-Windows counterpart of the PATH fix-up: intentionally does nothing.
#[cfg(not(target_os = "windows"))]
fn ensure_cudnn_on_path(_extra_dirs: &[PathBuf]) {}
2534
2535/// Open libcudnn across the usual baracuda search paths plus cuDNN-specific ones.
2536fn open_cudnn() -> Result<Library, LoaderError> {
2537    let candidates: Vec<&'static str> = cudnn_candidates()
2538        .into_iter()
2539        .map(|s| Box::leak(s.into_boxed_str()) as &'static str)
2540        .collect();
2541    let candidates_leaked: &'static [&'static str] = Box::leak(candidates.into_boxed_slice());
2542
2543    // Make sure cuDNN-specific directories are on PATH so the main DLL can
2544    // find its companion DLLs via Windows' dependency resolver.
2545    let extra = cudnn_extra_search_dirs();
2546    ensure_cudnn_on_path(&extra);
2547
2548    // First try the standard baracuda search (now augmented on PATH).
2549    if let Ok(lib) = Library::open("cudnn", candidates_leaked) {
2550        return Ok(lib);
2551    }
2552
2553    // Then try cuDNN-specific directories explicitly.
2554    for dir in &extra {
2555        for candidate in candidates_leaked {
2556            let full = dir.join(candidate);
2557            if let Ok(lib) = Library::open_at("cudnn", &full) {
2558                return Ok(lib);
2559            }
2560        }
2561    }
2562
2563    Err(LoaderError::library_not_found_with_search(
2564        "cudnn",
2565        candidates_leaked,
2566        extra.len(),
2567    ))
2568}
2569
// Generates the `Cudnn` symbol table from a list of
// `accessor as "exported symbol": fn-pointer type;` rows:
//   * a struct holding the loaded `Library` plus one `OnceLock` slot per row,
//   * a lazy, caching accessor method per row,
//   * a private `empty` constructor and a `Debug` impl that prints only `lib`.
macro_rules! cudnn_fns {
    ($($name:ident as $sym:literal : $pfn:ty);* $(;)?) => {
        /// Table of cuDNN entry points, each resolved from the loaded
        /// library on first use and cached afterwards.
        pub struct Cudnn {
            lib: Library,
            $($name: OnceLock<$pfn>,)*
        }

        impl Cudnn {
            /// Wraps `lib` with every symbol slot still unresolved.
            fn empty(lib: Library) -> Self {
                Self { lib, $($name: OnceLock::new(),)* }
            }

            $(
                #[doc = concat!("Resolve `", $sym, "`.")]
                pub fn $name(&self) -> Result<$pfn, LoaderError> {
                    match self.$name.get() {
                        // Already looked up earlier: hand back the cached pointer.
                        Some(cached) => Ok(*cached),
                        None => {
                            // SAFETY: NOTE(review) — relies on the contract of
                            // `Library::raw_symbol`, which is not visible in this
                            // chunk; confirm against its definition.
                            let addr: *mut () = unsafe { self.lib.raw_symbol($sym)? };
                            // SAFETY: reinterprets the symbol address as the declared
                            // entry-point type. Presumably `$pfn` is a pointer-sized
                            // `extern` fn-pointer type matching the symbol's real
                            // signature — confirm against the `PFN_*` definitions.
                            let entry: $pfn = unsafe {
                                core::mem::transmute_copy::<*mut (), $pfn>(&addr)
                            };
                            // A failed `set` only means another caller filled the
                            // slot first; the value resolved here is returned either way.
                            let _ = self.$name.set(entry);
                            Ok(entry)
                        }
                    }
                }
            )*
        }

        impl core::fmt::Debug for Cudnn {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                // Only the library handle is shown; the per-symbol slots are elided.
                let mut out = f.debug_struct("Cudnn");
                out.field("lib", &self.lib);
                out.finish_non_exhaustive()
            }
        }
    };
}
2599
// Entry-point table for the `Cudnn` loader. Each row reads
// `accessor_method as "exported cuDNN symbol": function-pointer type;`
// and expands to one lazily-resolved accessor on `Cudnn`.
2600cudnn_fns! {
2601    // Handle lifecycle, stream binding, version + error-string queries
2602    cudnn_create as "cudnnCreate": PFN_cudnnCreate;
2603    cudnn_destroy as "cudnnDestroy": PFN_cudnnDestroy;
2604    cudnn_set_stream as "cudnnSetStream": PFN_cudnnSetStream;
2605    cudnn_get_version as "cudnnGetVersion": PFN_cudnnGetVersion;
2606    cudnn_get_error_string as "cudnnGetErrorString": PFN_cudnnGetErrorString;
2607    // Tensor descriptor (4-D setter; the Nd setters are listed further down)
2608    cudnn_create_tensor_descriptor as "cudnnCreateTensorDescriptor": PFN_cudnnCreateTensorDescriptor;
2609    cudnn_destroy_tensor_descriptor as "cudnnDestroyTensorDescriptor": PFN_cudnnDestroyTensorDescriptor;
2610    cudnn_set_tensor_4d_descriptor as "cudnnSetTensor4dDescriptor": PFN_cudnnSetTensor4dDescriptor;
2611    // Activation descriptor + forward pass (backward is in Tier 1 below)
2612    cudnn_create_activation_descriptor as "cudnnCreateActivationDescriptor": PFN_cudnnCreateActivationDescriptor;
2613    cudnn_destroy_activation_descriptor as "cudnnDestroyActivationDescriptor": PFN_cudnnDestroyActivationDescriptor;
2614    cudnn_set_activation_descriptor as "cudnnSetActivationDescriptor": PFN_cudnnSetActivationDescriptor;
2615    cudnn_activation_forward as "cudnnActivationForward": PFN_cudnnActivationForward;
2616    // Convolution: filter/convolution descriptors, workspace queries, forward + backward
2617    cudnn_create_filter_descriptor as "cudnnCreateFilterDescriptor": PFN_cudnnCreateFilterDescriptor;
2618    cudnn_destroy_filter_descriptor as "cudnnDestroyFilterDescriptor": PFN_cudnnDestroyFilterDescriptor;
2619    cudnn_set_filter_4d_descriptor as "cudnnSetFilter4dDescriptor": PFN_cudnnSetFilter4dDescriptor;
2620    cudnn_create_convolution_descriptor as "cudnnCreateConvolutionDescriptor": PFN_cudnnCreateConvolutionDescriptor;
2621    cudnn_destroy_convolution_descriptor as "cudnnDestroyConvolutionDescriptor": PFN_cudnnDestroyConvolutionDescriptor;
2622    cudnn_set_convolution_2d_descriptor as "cudnnSetConvolution2dDescriptor": PFN_cudnnSetConvolution2dDescriptor;
2623    cudnn_get_convolution_2d_forward_output_dim as "cudnnGetConvolution2dForwardOutputDim": PFN_cudnnGetConvolution2dForwardOutputDim;
2624    cudnn_get_convolution_forward_workspace_size as "cudnnGetConvolutionForwardWorkspaceSize": PFN_cudnnGetConvolutionForwardWorkspaceSize;
2625    cudnn_convolution_forward as "cudnnConvolutionForward": PFN_cudnnConvolutionForward;
2626    cudnn_convolution_backward_data as "cudnnConvolutionBackwardData": PFN_cudnnConvolutionBackwardData;
2627    cudnn_convolution_backward_filter as "cudnnConvolutionBackwardFilter": PFN_cudnnConvolutionBackwardFilter;
2628    cudnn_convolution_backward_bias as "cudnnConvolutionBackwardBias": PFN_cudnnConvolutionBackwardBias;
2629    cudnn_get_convolution_backward_data_workspace_size as "cudnnGetConvolutionBackwardDataWorkspaceSize": PFN_cudnnGetConvolutionBackwardDataWorkspaceSize;
2630    cudnn_get_convolution_backward_filter_workspace_size as "cudnnGetConvolutionBackwardFilterWorkspaceSize": PFN_cudnnGetConvolutionBackwardFilterWorkspaceSize;
2631    // Pooling descriptor (2-D setter) + forward/backward
2632    cudnn_create_pooling_descriptor as "cudnnCreatePoolingDescriptor": PFN_cudnnCreatePoolingDescriptor;
2633    cudnn_destroy_pooling_descriptor as "cudnnDestroyPoolingDescriptor": PFN_cudnnDestroyPoolingDescriptor;
2634    cudnn_set_pooling_2d_descriptor as "cudnnSetPooling2dDescriptor": PFN_cudnnSetPooling2dDescriptor;
2635    cudnn_pooling_forward as "cudnnPoolingForward": PFN_cudnnPoolingForward;
2636    cudnn_pooling_backward as "cudnnPoolingBackward": PFN_cudnnPoolingBackward;
2637    // Softmax forward/backward
2638    cudnn_softmax_forward as "cudnnSoftmaxForward": PFN_cudnnSoftmaxForward;
2639    cudnn_softmax_backward as "cudnnSoftmaxBackward": PFN_cudnnSoftmaxBackward;
2640    // BatchNorm (the `Ex` variants are in Tier 3 below)
2641    cudnn_batch_normalization_forward_inference as "cudnnBatchNormalizationForwardInference": PFN_cudnnBatchNormalizationForwardInference;
2642    cudnn_batch_normalization_forward_training as "cudnnBatchNormalizationForwardTraining": PFN_cudnnBatchNormalizationForwardTraining;
2643    cudnn_batch_normalization_backward as "cudnnBatchNormalizationBackward": PFN_cudnnBatchNormalizationBackward;
2644    // Op-tensor / reduce / transform, plus add/scale/set tensor helpers
2645    cudnn_create_op_tensor_descriptor as "cudnnCreateOpTensorDescriptor": PFN_cudnnCreateOpTensorDescriptor;
2646    cudnn_destroy_op_tensor_descriptor as "cudnnDestroyOpTensorDescriptor": PFN_cudnnDestroyOpTensorDescriptor;
2647    cudnn_set_op_tensor_descriptor as "cudnnSetOpTensorDescriptor": PFN_cudnnSetOpTensorDescriptor;
2648    cudnn_op_tensor as "cudnnOpTensor": PFN_cudnnOpTensor;
2649    cudnn_create_reduce_tensor_descriptor as "cudnnCreateReduceTensorDescriptor": PFN_cudnnCreateReduceTensorDescriptor;
2650    cudnn_destroy_reduce_tensor_descriptor as "cudnnDestroyReduceTensorDescriptor": PFN_cudnnDestroyReduceTensorDescriptor;
2651    cudnn_set_reduce_tensor_descriptor as "cudnnSetReduceTensorDescriptor": PFN_cudnnSetReduceTensorDescriptor;
2652    cudnn_get_reduction_workspace_size as "cudnnGetReductionWorkspaceSize": PFN_cudnnGetReductionWorkspaceSize;
2653    cudnn_reduce_tensor as "cudnnReduceTensor": PFN_cudnnReduceTensor;
2654    cudnn_add_tensor as "cudnnAddTensor": PFN_cudnnAddTensor;
2655    cudnn_transform_tensor as "cudnnTransformTensor": PFN_cudnnTransformTensor;
2656    cudnn_scale_tensor as "cudnnScaleTensor": PFN_cudnnScaleTensor;
2657    cudnn_set_tensor as "cudnnSetTensor": PFN_cudnnSetTensor;
2658    // LRN descriptor + cross-channel forward (backward is in Tier 1 below)
2659    cudnn_create_lrn_descriptor as "cudnnCreateLRNDescriptor": PFN_cudnnCreateLRNDescriptor;
2660    cudnn_destroy_lrn_descriptor as "cudnnDestroyLRNDescriptor": PFN_cudnnDestroyLRNDescriptor;
2661    cudnn_set_lrn_descriptor as "cudnnSetLRNDescriptor": PFN_cudnnSetLRNDescriptor;
2662    cudnn_lrn_cross_channel_forward as "cudnnLRNCrossChannelForward": PFN_cudnnLRNCrossChannelForward;
2663    // Dropout descriptor, size queries, forward/backward
2664    cudnn_create_dropout_descriptor as "cudnnCreateDropoutDescriptor": PFN_cudnnCreateDropoutDescriptor;
2665    cudnn_destroy_dropout_descriptor as "cudnnDestroyDropoutDescriptor": PFN_cudnnDestroyDropoutDescriptor;
2666    cudnn_dropout_get_states_size as "cudnnDropoutGetStatesSize": PFN_cudnnDropoutGetStatesSize;
2667    cudnn_dropout_get_reserve_space_size as "cudnnDropoutGetReserveSpaceSize": PFN_cudnnDropoutGetReserveSpaceSize;
2668    cudnn_set_dropout_descriptor as "cudnnSetDropoutDescriptor": PFN_cudnnSetDropoutDescriptor;
2669    cudnn_dropout_forward as "cudnnDropoutForward": PFN_cudnnDropoutForward;
2670    cudnn_dropout_backward as "cudnnDropoutBackward": PFN_cudnnDropoutBackward;
2671    // RNN descriptors + forward (backward and v8 setup are listed further down)
2672    cudnn_create_rnn_descriptor as "cudnnCreateRNNDescriptor": PFN_cudnnCreateRNNDescriptor;
2673    cudnn_destroy_rnn_descriptor as "cudnnDestroyRNNDescriptor": PFN_cudnnDestroyRNNDescriptor;
2674    cudnn_create_rnn_data_descriptor as "cudnnCreateRNNDataDescriptor": PFN_cudnnCreateRNNDataDescriptor;
2675    cudnn_destroy_rnn_data_descriptor as "cudnnDestroyRNNDataDescriptor": PFN_cudnnDestroyRNNDataDescriptor;
2676    cudnn_rnn_forward as "cudnnRNNForward": PFN_cudnnRNNForward;
2677    // Graph / backend API
2678    cudnn_backend_create_descriptor as "cudnnBackendCreateDescriptor": PFN_cudnnBackendCreateDescriptor;
2679    cudnn_backend_destroy_descriptor as "cudnnBackendDestroyDescriptor": PFN_cudnnBackendDestroyDescriptor;
2680    cudnn_backend_initialize as "cudnnBackendInitialize": PFN_cudnnBackendInitialize;
2681    cudnn_backend_finalize as "cudnnBackendFinalize": PFN_cudnnBackendFinalize;
2682    cudnn_backend_set_attribute as "cudnnBackendSetAttribute": PFN_cudnnBackendSetAttribute;
2683    cudnn_backend_get_attribute as "cudnnBackendGetAttribute": PFN_cudnnBackendGetAttribute;
2684    cudnn_backend_execute as "cudnnBackendExecute": PFN_cudnnBackendExecute;
2685    // Nd descriptors (tensor, filter, convolution, pooling)
2686    cudnn_set_tensor_nd_descriptor as "cudnnSetTensorNdDescriptor": PFN_cudnnSetTensorNdDescriptor;
2687    cudnn_get_tensor_nd_descriptor as "cudnnGetTensorNdDescriptor": PFN_cudnnGetTensorNdDescriptor;
2688    cudnn_set_filter_nd_descriptor as "cudnnSetFilterNdDescriptor": PFN_cudnnSetFilterNdDescriptor;
2689    cudnn_set_convolution_nd_descriptor as "cudnnSetConvolutionNdDescriptor": PFN_cudnnSetConvolutionNdDescriptor;
2690    cudnn_set_pooling_nd_descriptor as "cudnnSetPoolingNdDescriptor": PFN_cudnnSetPoolingNdDescriptor;
2691    // CTC loss descriptor, workspace query, loss computation
2692    cudnn_create_ctc_loss_descriptor as "cudnnCreateCTCLossDescriptor": PFN_cudnnCreateCTCLossDescriptor;
2693    cudnn_destroy_ctc_loss_descriptor as "cudnnDestroyCTCLossDescriptor": PFN_cudnnDestroyCTCLossDescriptor;
2694    cudnn_set_ctc_loss_descriptor as "cudnnSetCTCLossDescriptor": PFN_cudnnSetCTCLossDescriptor;
2695    cudnn_get_ctc_loss_workspace_size as "cudnnGetCTCLossWorkspaceSize": PFN_cudnnGetCTCLossWorkspaceSize;
2696    cudnn_ctc_loss as "cudnnCTCLoss": PFN_cudnnCTCLoss;
2697    // RNN backward (`_v8` entry points)
2698    cudnn_rnn_backward_data_v8 as "cudnnRNNBackwardData_v8": PFN_cudnnRNNBackwardData_v8;
2699    cudnn_rnn_backward_weights_v8 as "cudnnRNNBackwardWeights_v8": PFN_cudnnRNNBackwardWeights_v8;
2700    // Spatial transformer descriptor, grid generator + sampler forward
2701    cudnn_create_spatial_transformer_descriptor as "cudnnCreateSpatialTransformerDescriptor": PFN_cudnnCreateSpatialTransformerDescriptor;
2702    cudnn_destroy_spatial_transformer_descriptor as "cudnnDestroySpatialTransformerDescriptor": PFN_cudnnDestroySpatialTransformerDescriptor;
2703    cudnn_set_spatial_transformer_nd_descriptor as "cudnnSetSpatialTransformerNdDescriptor": PFN_cudnnSetSpatialTransformerNdDescriptor;
2704    cudnn_spatial_tf_grid_generator_forward as "cudnnSpatialTfGridGeneratorForward": PFN_cudnnSpatialTfGridGeneratorForward;
2705    cudnn_spatial_tf_sampler_forward as "cudnnSpatialTfSamplerForward": PFN_cudnnSpatialTfSamplerForward;
2706
2707    // Tier 1 — convolution + activation + reduction misc (also descriptor getters and dropout restore)
2708    cudnn_set_convolution_group_count as "cudnnSetConvolutionGroupCount": PFN_cudnnSetConvolutionGroupCount;
2709    cudnn_get_convolution_group_count as "cudnnGetConvolutionGroupCount": PFN_cudnnGetConvolutionGroupCount;
2710    cudnn_set_convolution_math_type as "cudnnSetConvolutionMathType": PFN_cudnnSetConvolutionMathType;
2711    cudnn_get_convolution_math_type as "cudnnGetConvolutionMathType": PFN_cudnnGetConvolutionMathType;
2712    cudnn_set_convolution_reorder_type as "cudnnSetConvolutionReorderType": PFN_cudnnSetConvolutionReorderType;
2713    cudnn_get_convolution_reorder_type as "cudnnGetConvolutionReorderType": PFN_cudnnGetConvolutionReorderType;
2714    cudnn_reorder_filter_and_bias as "cudnnReorderFilterAndBias": PFN_cudnnReorderFilterAndBias;
2715    cudnn_convolution_bias_activation_forward as "cudnnConvolutionBiasActivationForward": PFN_cudnnConvolutionBiasActivationForward;
2716    cudnn_activation_backward as "cudnnActivationBackward": PFN_cudnnActivationBackward;
2717    cudnn_set_activation_descriptor_swish_beta as "cudnnSetActivationDescriptorSwishBeta": PFN_cudnnSetActivationDescriptorSwishBeta;
2718    cudnn_get_activation_descriptor_swish_beta as "cudnnGetActivationDescriptorSwishBeta": PFN_cudnnGetActivationDescriptorSwishBeta;
2719    cudnn_lrn_cross_channel_backward as "cudnnLRNCrossChannelBackward": PFN_cudnnLRNCrossChannelBackward;
2720    cudnn_divisive_normalization_forward as "cudnnDivisiveNormalizationForward": PFN_cudnnDivisiveNormalizationForward;
2721    cudnn_divisive_normalization_backward as "cudnnDivisiveNormalizationBackward": PFN_cudnnDivisiveNormalizationBackward;
2722    cudnn_get_reduction_indices_size as "cudnnGetReductionIndicesSize": PFN_cudnnGetReductionIndicesSize;
2723    cudnn_set_tensor_4d_descriptor_ex as "cudnnSetTensor4dDescriptorEx": PFN_cudnnSetTensor4dDescriptorEx;
2724    cudnn_get_tensor_4d_descriptor as "cudnnGetTensor4dDescriptor": PFN_cudnnGetTensor4dDescriptor;
2725    cudnn_get_filter_4d_descriptor as "cudnnGetFilter4dDescriptor": PFN_cudnnGetFilter4dDescriptor;
2726    cudnn_get_dropout_descriptor as "cudnnGetDropoutDescriptor": PFN_cudnnGetDropoutDescriptor;
2727    cudnn_restore_dropout_descriptor as "cudnnRestoreDropoutDescriptor": PFN_cudnnRestoreDropoutDescriptor;
2728
2729    // Tier 2 — algorithm finders / pickers (`Get…_v7` and `Find…` entry points)
2730    cudnn_get_convolution_forward_algorithm_v7 as "cudnnGetConvolutionForwardAlgorithm_v7": PFN_cudnnGetConvolutionForwardAlgorithm_v7;
2731    cudnn_find_convolution_forward_algorithm as "cudnnFindConvolutionForwardAlgorithm": PFN_cudnnFindConvolutionForwardAlgorithm;
2732    cudnn_find_convolution_forward_algorithm_ex as "cudnnFindConvolutionForwardAlgorithmEx": PFN_cudnnFindConvolutionForwardAlgorithmEx;
2733    cudnn_get_convolution_backward_data_algorithm_v7 as "cudnnGetConvolutionBackwardDataAlgorithm_v7": PFN_cudnnGetConvolutionBackwardDataAlgorithm_v7;
2734    cudnn_find_convolution_backward_data_algorithm as "cudnnFindConvolutionBackwardDataAlgorithm": PFN_cudnnFindConvolutionBackwardDataAlgorithm;
2735    cudnn_get_convolution_backward_filter_algorithm_v7 as "cudnnGetConvolutionBackwardFilterAlgorithm_v7": PFN_cudnnGetConvolutionBackwardFilterAlgorithm_v7;
2736    cudnn_find_convolution_backward_filter_algorithm as "cudnnFindConvolutionBackwardFilterAlgorithm": PFN_cudnnFindConvolutionBackwardFilterAlgorithm;
2737
2738    // Tier 3 — BatchNorm Ex + generic Normalization (size queries, forward, backward)
2739    cudnn_get_batch_normalization_forward_training_ex_workspace_size as "cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize": PFN_cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize;
2740    cudnn_get_batch_normalization_backward_ex_workspace_size as "cudnnGetBatchNormalizationBackwardExWorkspaceSize": PFN_cudnnGetBatchNormalizationBackwardExWorkspaceSize;
2741    cudnn_get_batch_normalization_training_ex_reserve_space_size as "cudnnGetBatchNormalizationTrainingExReserveSpaceSize": PFN_cudnnGetBatchNormalizationTrainingExReserveSpaceSize;
2742    cudnn_batch_normalization_forward_training_ex as "cudnnBatchNormalizationForwardTrainingEx": PFN_cudnnBatchNormalizationForwardTrainingEx;
2743    cudnn_batch_normalization_backward_ex as "cudnnBatchNormalizationBackwardEx": PFN_cudnnBatchNormalizationBackwardEx;
2744    cudnn_normalization_forward_inference as "cudnnNormalizationForwardInference": PFN_cudnnNormalizationForwardInference;
2745    cudnn_get_normalization_forward_training_workspace_size as "cudnnGetNormalizationForwardTrainingWorkspaceSize": PFN_cudnnGetNormalizationForwardTrainingWorkspaceSize;
2746    cudnn_get_normalization_backward_workspace_size as "cudnnGetNormalizationBackwardWorkspaceSize": PFN_cudnnGetNormalizationBackwardWorkspaceSize;
2747    cudnn_get_normalization_training_reserve_space_size as "cudnnGetNormalizationTrainingReserveSpaceSize": PFN_cudnnGetNormalizationTrainingReserveSpaceSize;
2748    cudnn_normalization_forward_training as "cudnnNormalizationForwardTraining": PFN_cudnnNormalizationForwardTraining;
2749    cudnn_normalization_backward as "cudnnNormalizationBackward": PFN_cudnnNormalizationBackward;
2750
2751    // Tier 4 — RNN v8 (descriptor setup, dynamic build, space/weight queries)
2752    cudnn_set_rnn_descriptor_v8 as "cudnnSetRNNDescriptor_v8": PFN_cudnnSetRNNDescriptor_v8;
2753    cudnn_build_rnn_dynamic as "cudnnBuildRNNDynamic": PFN_cudnnBuildRNNDynamic;
2754    cudnn_get_rnn_temp_space_sizes as "cudnnGetRNNTempSpaceSizes": PFN_cudnnGetRNNTempSpaceSizes;
2755    cudnn_get_rnn_weight_space_size as "cudnnGetRNNWeightSpaceSize": PFN_cudnnGetRNNWeightSpaceSize;
2756    cudnn_get_rnn_weight_params as "cudnnGetRNNWeightParams": PFN_cudnnGetRNNWeightParams;
2757
2758    // Tier 5 — Multi-head attention (attention + sequence-data descriptors)
2759    cudnn_create_attn_descriptor as "cudnnCreateAttnDescriptor": PFN_cudnnCreateAttnDescriptor;
2760    cudnn_destroy_attn_descriptor as "cudnnDestroyAttnDescriptor": PFN_cudnnDestroyAttnDescriptor;
2761    cudnn_set_attn_descriptor as "cudnnSetAttnDescriptor": PFN_cudnnSetAttnDescriptor;
2762    cudnn_get_multi_head_attn_buffers as "cudnnGetMultiHeadAttnBuffers": PFN_cudnnGetMultiHeadAttnBuffers;
2763    cudnn_get_multi_head_attn_weights as "cudnnGetMultiHeadAttnWeights": PFN_cudnnGetMultiHeadAttnWeights;
2764    cudnn_multi_head_attn_forward as "cudnnMultiHeadAttnForward": PFN_cudnnMultiHeadAttnForward;
2765    cudnn_multi_head_attn_backward_data as "cudnnMultiHeadAttnBackwardData": PFN_cudnnMultiHeadAttnBackwardData;
2766    cudnn_multi_head_attn_backward_weights as "cudnnMultiHeadAttnBackwardWeights": PFN_cudnnMultiHeadAttnBackwardWeights;
2767    cudnn_create_seq_data_descriptor as "cudnnCreateSeqDataDescriptor": PFN_cudnnCreateSeqDataDescriptor;
2768    cudnn_destroy_seq_data_descriptor as "cudnnDestroySeqDataDescriptor": PFN_cudnnDestroySeqDataDescriptor;
2769    cudnn_set_seq_data_descriptor as "cudnnSetSeqDataDescriptor": PFN_cudnnSetSeqDataDescriptor;
2770}
2771
2772/// Lazily-initialized process-wide cuDNN loader singleton.
2773pub fn cudnn() -> Result<&'static Cudnn, LoaderError> {
2774    static CUDNN: OnceLock<Cudnn> = OnceLock::new();
2775    if let Some(c) = CUDNN.get() {
2776        return Ok(c);
2777    }
2778    let lib = open_cudnn()?;
2779    let c = Cudnn::empty(lib);
2780    let _ = CUDNN.set(c);
2781    Ok(CUDNN.get().expect("OnceLock set or lost race"))
2782}
2783
#[cfg(test)]
mod search_dir_tests {
    use super::*;

    /// A `vN.M` segment yields `N`, whether it is the last path component
    /// or followed by further components, with either separator style.
    #[test]
    fn parse_cuda_major_handles_typical_windows_paths() {
        let cases = [
            (
                r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6",
                12,
            ),
            (
                r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin",
                11,
            ),
            // Linux-style.
            ("/opt/cuda/v13.0", 13),
        ];
        for (path, major) in cases {
            assert_eq!(parse_cuda_major_from_path(path), Some(major));
        }
    }

    /// Paths without a `vN.M` segment produce no major version, even when
    /// a component merely starts with the letter `v` (e.g. `verbose`).
    #[test]
    fn parse_cuda_major_ignores_unrelated_v_prefixed_words() {
        for path in ["/usr/local/verbose/cuda", "", "/usr/local/cuda"] {
            assert_eq!(parse_cuda_major_from_path(path), None);
        }
    }

    /// The version marker is recognised with an upper-case `V` as well.
    #[test]
    fn parse_cuda_major_accepts_uppercase_v() {
        let major = parse_cuda_major_from_path(r"D:\NVIDIA\CUDA\V12.6\bin");
        assert_eq!(major, Some(12));
    }

    /// Reproduces the multi-CUDA bug: with the *old* logic, all `bin/<n>`
    /// subdirs would be pushed in directory-iteration order (effectively
    /// arbitrary). Now only the one matching the detected major is picked.
    #[test]
    fn dir_selection_prefers_target_major() {
        // The numbered-subdir policy is simulated inline; the filesystem
        // is never touched.
        let installed: Vec<(u32, &str)> = vec![(11, "/cudnn/bin/11"), (12, "/cudnn/bin/12")];
        let detected = Some(12u32);

        let picked: Vec<&str> = if let Some(wanted) = detected {
            installed
                .iter()
                .filter(|(major, _)| *major == wanted)
                .map(|(_, dir)| *dir)
                .take(1)
                .collect()
        } else {
            installed.iter().map(|(_, dir)| *dir).collect()
        };
        assert_eq!(picked, vec!["/cudnn/bin/12"]);
    }

    /// Detected major = 13 but only cuDNN/12 + cuDNN/11 are installed:
    /// the highest major that is still <= the target wins.
    #[test]
    fn dir_selection_falls_back_to_highest_le_target() {
        let mut installed: Vec<(u32, &str)> = vec![(11, "/cudnn/11"), (12, "/cudnn/12")];
        let detected = 13u32;

        // Policy replica: an exact match would be taken as-is; otherwise
        // sort descending and take the first entry not above the target.
        let picked = if installed.iter().any(|(major, _)| *major == detected) {
            unreachable!("no exact match in this scenario")
        } else {
            installed.sort_by_key(|entry| std::cmp::Reverse(entry.0));
            installed
                .iter()
                .find(|(major, _)| *major <= detected)
                .map(|(_, dir)| *dir)
        };
        assert_eq!(picked, Some("/cudnn/12"));
    }

    /// Without a detected major, candidates are ordered newest-first.
    #[test]
    fn dir_selection_no_signal_is_highest_first() {
        let mut installed: Vec<(u32, &str)> = vec![(11, "/11"), (13, "/13"), (12, "/12")];
        installed.sort_by_key(|entry| std::cmp::Reverse(entry.0));
        let ordered: Vec<&str> = installed.into_iter().map(|(_, dir)| dir).collect();
        assert_eq!(ordered, vec!["/13", "/12", "/11"]);
    }
}