Skip to main content

baracuda_cudnn/
lib.rs

1//! Safe Rust wrappers for NVIDIA cuDNN.
2//!
3//! Layered on top of [`baracuda-cudnn-sys`](https://docs.rs/baracuda-cudnn-sys).
4//! Use this crate directly for typed, RAII-managed cuDNN handles +
5//! descriptors; reach for `-sys` only when adding a function the safe
6//! layer doesn't expose yet.
7//!
8//! ## Scope
9//!
10//! Covers the cuDNN classic API surface that `baracuda-kernels`'s
11//! Phase 7+ Conv2d / Pool2d / CTCLoss / BatchNorm / GroupNorm plans
12//! and the Phase 11 Conv1d/3d/Transpose/depthwise + Pool1d/3d/Adaptive
13//! fanout dispatch through. Concretely:
14//!
15//! - Handle management + stream binding.
16//! - Tensor / filter / convolution / pooling / activation /
17//!   batch-norm / RNN / dropout / op-tensor / reduce-tensor /
18//!   LRN / SpatialTransform / Attn descriptors.
19//! - Conv2d / Conv1d / Conv3d (FW + BW data + BW weight) with all
20//!   algo enums.
21//! - Pool2d / Pool1d / Pool3d (Avg + Max, deterministic + non-det).
22//! - BatchNorm FW training/inference + BW + persistent mode.
23//! - LRN, Softmax (classic — modern softmax is bespoke in
24//!   `baracuda-kernels`).
25//! - CTC loss FW + BW (the cuDNN path; bespoke
26//!   `baracuda-kernels::CtcLossPlan` covers the non-cuDNN path).
27//! - Op-tensor + reduce-tensor (gluing primitives for fused ops).
28//! - RNN classic API (cells, sequences, persistent).
29//! - DropoutDescriptor + state management.
30//!
//! The cuDNN **backend / graph API** (the modern fusion API) is NOT
//! wrapped here. `baracuda-kernels` already builds bespoke fused kernels
//! directly via `baracuda-kernels-sys` for the ops where graph-API
//! fusion would be the win, so wrapping the graph API as well would
//! duplicate that work, and its maintenance cost hasn't been justified yet.
36//!
37//! ## Build requirement
38//!
//! cuDNN is a **separate NVIDIA download** that is not bundled with the
//! stock CUDA toolkit. The `baracuda-kernels-sys` build script
//! auto-discovers it via the `CUDNN_PATH` / `CUDNN_ROOT` / `CUDNN_HOME`
//! env vars or the standard Windows / Linux install paths — see the
//! "Building" section of the workspace
//! [`README.md`](https://github.com/ciresnave/baracuda#building)
//! for the full probe order.
45
46#![warn(missing_debug_implementations)]
47
48use baracuda_cudnn_sys::{
49    cudnn, cudnnActivationDescriptor_t, cudnnActivationMode_t, cudnnAttnDescriptor_t,
50    cudnnBackendAttributeName_t, cudnnBackendAttributeType_t, cudnnBackendDescriptorType_t,
51    cudnnBackendDescriptor_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t,
52    cudnnConvolutionBwdDataAlgo_t, cudnnConvolutionBwdFilterAlgo_t,
53    cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, cudnnConvolutionMode_t,
54    cudnnDataType_t, cudnnDropoutDescriptor_t, cudnnFilterDescriptor_t,
55    cudnnHandle_t, cudnnIndicesType_t, cudnnLRNDescriptor_t, cudnnMathType_t, cudnnNanPropagation_t,
56    cudnnNormAlgo_t, cudnnNormMode_t, cudnnNormOps_t, cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t,
57    cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnReduceTensorDescriptor_t,
58    cudnnReduceTensorIndices_t, cudnnReduceTensorOp_t, cudnnReorderType_t,
59    cudnnSeqDataDescriptor_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, cudnnStatus_t,
60    cudnnTensorDescriptor_t, cudnnTensorFormat_t,
61};
62use baracuda_driver::{DeviceBuffer, Stream};
63use baracuda_types::DeviceRepr;
64
65/// Error type for cuDNN operations.
66pub type Error = baracuda_core::Error<cudnnStatus_t>;
67/// Result alias.
68pub type Result<T, E = Error> = core::result::Result<T, E>;
69
70#[inline]
71fn check(status: cudnnStatus_t) -> Result<()> {
72    Error::check(status)
73}
74
75/// cuDNN context handle.
76pub struct Handle {
77    handle: cudnnHandle_t,
78}
79
80unsafe impl Send for Handle {}
81
82impl core::fmt::Debug for Handle {
83    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
84        f.debug_struct("cudnn::Handle")
85            .field("handle", &self.handle)
86            .finish()
87    }
88}
89
90impl Handle {
91    /// Create a new cuDNN handle.
92    pub fn new() -> Result<Self> {
93        let c = cudnn()?;
94        let cu = c.cudnn_create()?;
95        let mut h: cudnnHandle_t = core::ptr::null_mut();
96        check(unsafe { cu(&mut h) })?;
97        Ok(Self { handle: h })
98    }
99
100    /// Bind operations on this handle to `stream`.
101    pub fn set_stream(&self, stream: &Stream) -> Result<()> {
102        let c = cudnn()?;
103        let cu = c.cudnn_set_stream()?;
104        check(unsafe { cu(self.handle, stream.as_raw() as _) })
105    }
106
107    /// Raw handle.
108    #[inline]
109    pub fn as_raw(&self) -> cudnnHandle_t {
110        self.handle
111    }
112}
113
114impl Drop for Handle {
115    fn drop(&mut self) {
116        if let Ok(c) = cudnn() {
117            if let Ok(cu) = c.cudnn_destroy() {
118                let _ = unsafe { cu(self.handle) };
119            }
120        }
121    }
122}
123
124/// cuDNN library version as a packed integer (e.g. `9106` for 9.1.6).
125///
126/// Does **not** require an initialized handle.
127pub fn version() -> Result<usize> {
128    let c = cudnn()?;
129    let cu = c.cudnn_get_version()?;
130    // SAFETY: cudnnGetVersion has no error path.
131    Ok(unsafe { cu() })
132}
133
134/// Element dtype for a tensor.
135#[derive(Copy, Clone, Debug, Eq, PartialEq)]
136pub enum DType {
137    /// Single-precision 32-bit floating point.
138    F32,
139    /// Double-precision 64-bit floating point.
140    F64,
141    /// IEEE 754 half-precision (16-bit) floating point.
142    F16,
143    /// Brain half-precision (16-bit) floating point.
144    BF16,
145    /// 8-bit signed integer (quantized inference).
146    I8,
147    /// 32-bit signed integer (integer accumulators).
148    I32,
149}
150
151impl DType {
152    fn raw(self) -> cudnnDataType_t {
153        match self {
154            DType::F32 => cudnnDataType_t::Float,
155            DType::F64 => cudnnDataType_t::Double,
156            DType::F16 => cudnnDataType_t::Half,
157            DType::BF16 => cudnnDataType_t::BFloat16,
158            DType::I8 => cudnnDataType_t::Int8,
159            DType::I32 => cudnnDataType_t::Int32,
160        }
161    }
162}
163
164/// Trait mapping Rust element types to their cuDNN [`DType`] tag.
165///
166/// Lets generic code accept "a tensor of T" and recover the cuDNN dtype
167/// with `T::DTYPE`, instead of threading a `DType` argument through every
168/// call. Useful for tensor-descriptor builders:
169///
170/// ```no_run
171/// use baracuda_cudnn::{CudnnDataType, DType, TensorDescriptor, TensorFormat};
172///
173/// fn make_nchw<T: CudnnDataType>(n: i32, c: i32, h: i32, w: i32)
174///     -> baracuda_cudnn::Result<TensorDescriptor>
175/// {
176///     TensorDescriptor::new_4d(TensorFormat::Nchw, T::DTYPE, n, c, h, w)
177/// }
178///
179/// let desc = make_nchw::<f32>(1, 3, 224, 224)?;
180/// # Ok::<(), baracuda_cudnn::Error>(())
181/// ```
182///
183/// Implementors: `f32`, `f64`, [`baracuda_types::Half`],
184/// [`baracuda_types::BFloat16`], `i8`, `i32`.
185pub trait CudnnDataType: DeviceRepr + Copy + 'static {
186    /// The [`DType`] tag cuDNN uses for this scalar type.
187    const DTYPE: DType;
188}
189
190impl CudnnDataType for f32 {
191    const DTYPE: DType = DType::F32;
192}
193impl CudnnDataType for f64 {
194    const DTYPE: DType = DType::F64;
195}
196impl CudnnDataType for baracuda_types::Half {
197    const DTYPE: DType = DType::F16;
198}
199impl CudnnDataType for baracuda_types::BFloat16 {
200    const DTYPE: DType = DType::BF16;
201}
202impl CudnnDataType for i8 {
203    const DTYPE: DType = DType::I8;
204}
205impl CudnnDataType for i32 {
206    const DTYPE: DType = DType::I32;
207}
208
209// Direct impls on the `half` crate's types so callers can use
210// `half::f16` / `half::bf16` end-to-end without bridging through
211// `baracuda_types::Half` / `BFloat16`. Both directions of `From` are
212// already available in baracuda-types under the same feature.
213#[cfg(feature = "half-crate")]
214impl CudnnDataType for half::f16 {
215    const DTYPE: DType = DType::F16;
216}
217#[cfg(feature = "half-crate")]
218impl CudnnDataType for half::bf16 {
219    const DTYPE: DType = DType::BF16;
220}
221
222/// Memory layout for a 4-D tensor.
223#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
224pub enum TensorFormat {
225    /// Batch × Channels × Height × Width (channels-first, the cuDNN default).
226    #[default]
227    Nchw,
228    /// Batch × Height × Width × Channels (channels-last).
229    Nhwc,
230}
231
232impl TensorFormat {
233    fn raw(self) -> cudnnTensorFormat_t {
234        match self {
235            TensorFormat::Nchw => cudnnTensorFormat_t::Nchw,
236            TensorFormat::Nhwc => cudnnTensorFormat_t::Nhwc,
237        }
238    }
239}
240
241/// A 4-D tensor descriptor.
242pub struct TensorDescriptor {
243    desc: cudnnTensorDescriptor_t,
244}
245
246unsafe impl Send for TensorDescriptor {}
247
248impl core::fmt::Debug for TensorDescriptor {
249    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
250        f.debug_struct("TensorDescriptor")
251            .field("desc", &self.desc)
252            .finish_non_exhaustive()
253    }
254}
255
256impl TensorDescriptor {
257    /// Describe an `N × C × H × W` tensor with the given format and dtype.
258    pub fn new_4d(
259        format: TensorFormat,
260        dtype: DType,
261        n: i32,
262        c: i32,
263        h: i32,
264        w: i32,
265    ) -> Result<Self> {
266        let cu_crate = cudnn()?;
267        let create = cu_crate.cudnn_create_tensor_descriptor()?;
268        let set = cu_crate.cudnn_set_tensor_4d_descriptor()?;
269        let mut desc: cudnnTensorDescriptor_t = core::ptr::null_mut();
270        check(unsafe { create(&mut desc) })?;
271        let this = Self { desc };
272        check(unsafe { set(this.desc, format.raw(), dtype.raw(), n, c, h, w) })?;
273        Ok(this)
274    }
275
276    /// Describe an N-dimensional tensor. `dims` and `strides` must have the
277    /// same length (≤8) and correspond to a valid, non-overlapping
278    /// cuDNN-supported layout.
279    pub fn new_nd(dtype: DType, dims: &[i32], strides: &[i32]) -> Result<Self> {
280        assert_eq!(
281            dims.len(),
282            strides.len(),
283            "dims/strides length mismatch for Nd tensor descriptor"
284        );
285        let cu_crate = cudnn()?;
286        let create = cu_crate.cudnn_create_tensor_descriptor()?;
287        let set = cu_crate.cudnn_set_tensor_nd_descriptor()?;
288        let mut desc: cudnnTensorDescriptor_t = core::ptr::null_mut();
289        check(unsafe { create(&mut desc) })?;
290        let this = Self { desc };
291        check(unsafe {
292            set(
293                this.desc,
294                dtype.raw(),
295                dims.len() as core::ffi::c_int,
296                dims.as_ptr(),
297                strides.as_ptr(),
298            )
299        })?;
300        Ok(this)
301    }
302
303    /// Raw descriptor. Use with care.
304    #[inline]
305    pub fn as_raw(&self) -> cudnnTensorDescriptor_t {
306        self.desc
307    }
308}
309
310impl Drop for TensorDescriptor {
311    fn drop(&mut self) {
312        if let Ok(c) = cudnn() {
313            if let Ok(cu) = c.cudnn_destroy_tensor_descriptor() {
314                let _ = unsafe { cu(self.desc) };
315            }
316        }
317    }
318}
319
320/// Activation function kind.
321#[derive(Copy, Clone, Debug, Eq, PartialEq)]
322pub enum ActivationMode {
323    /// Rectified linear: `max(0, x)`.
324    Relu,
325    /// Logistic sigmoid: `1 / (1 + exp(-x))`.
326    Sigmoid,
327    /// Hyperbolic tangent.
328    Tanh,
329    /// Clipped ReLU: `min(max(0, x), ceiling)`.
330    ClippedRelu,
331    /// Exponential linear unit: `x` if `x > 0`, else `α · (exp(x) - 1)`.
332    Elu,
333    /// Pass-through (no activation applied).
334    Identity,
335    /// Swish / SiLU: `x · sigmoid(x)`.
336    Swish,
337}
338
339impl ActivationMode {
340    fn raw(self) -> cudnnActivationMode_t {
341        match self {
342            ActivationMode::Relu => cudnnActivationMode_t::Relu,
343            ActivationMode::Sigmoid => cudnnActivationMode_t::Sigmoid,
344            ActivationMode::Tanh => cudnnActivationMode_t::Tanh,
345            ActivationMode::ClippedRelu => cudnnActivationMode_t::ClippedRelu,
346            ActivationMode::Elu => cudnnActivationMode_t::Elu,
347            ActivationMode::Identity => cudnnActivationMode_t::Identity,
348            ActivationMode::Swish => cudnnActivationMode_t::Swish,
349        }
350    }
351}
352
353/// An activation descriptor.
354pub struct ActivationDescriptor {
355    desc: cudnnActivationDescriptor_t,
356}
357
358unsafe impl Send for ActivationDescriptor {}
359
360impl core::fmt::Debug for ActivationDescriptor {
361    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
362        f.debug_struct("ActivationDescriptor")
363            .field("desc", &self.desc)
364            .finish_non_exhaustive()
365    }
366}
367
368impl ActivationDescriptor {
369    /// Create a descriptor for `mode`. `coef` is only used by ClippedReLU
370    /// (ceiling) and ELU (α); pass `0.0` when irrelevant.
371    pub fn new(mode: ActivationMode, coef: f64) -> Result<Self> {
372        let c = cudnn()?;
373        let create = c.cudnn_create_activation_descriptor()?;
374        let set = c.cudnn_set_activation_descriptor()?;
375        let mut desc: cudnnActivationDescriptor_t = core::ptr::null_mut();
376        check(unsafe { create(&mut desc) })?;
377        let this = Self { desc };
378        check(unsafe {
379            set(
380                this.desc,
381                mode.raw(),
382                cudnnNanPropagation_t::PropagateNan,
383                coef,
384            )
385        })?;
386        Ok(this)
387    }
388
389    /// Raw descriptor.
390    #[inline]
391    pub fn as_raw(&self) -> cudnnActivationDescriptor_t {
392        self.desc
393    }
394}
395
396impl Drop for ActivationDescriptor {
397    fn drop(&mut self) {
398        if let Ok(c) = cudnn() {
399            if let Ok(cu) = c.cudnn_destroy_activation_descriptor() {
400                let _ = unsafe { cu(self.desc) };
401            }
402        }
403    }
404}
405
406/// Compute `y = alpha * activation(x) + beta * y` element-wise.
407///
408/// `x` and `y` may alias (in-place activation is legal).
409///
410/// # Example
411///
412/// ReLU on a `1 × 16 × 8 × 8` NCHW tensor.
413///
414/// ```no_run
415/// use baracuda_driver::{Context, Device, DeviceBuffer};
416/// use baracuda_cudnn::{
417///     activation_forward, ActivationDescriptor, ActivationMode,
418///     DType, Handle, TensorDescriptor, TensorFormat,
419/// };
420///
421/// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
422/// let ctx = Context::new(&Device::get(0)?)?;
423/// let cudnn = Handle::new()?;
424///
425/// let (n, c, h, w) = (1, 16, 8, 8);
426/// let tdesc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, n, c, h, w)?;
427/// let act = ActivationDescriptor::new(ActivationMode::Relu, 0.0)?;
428///
429/// let x: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (n*c*h*w) as usize)?;
430/// let mut y: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (n*c*h*w) as usize)?;
431///
432/// activation_forward(&cudnn, &act, 1.0, &tdesc, &x, 0.0, &tdesc, &mut y)?;
433/// # Ok(()) }
434/// ```
435#[allow(clippy::too_many_arguments)]
436pub fn activation_forward<T: DeviceRepr>(
437    handle: &Handle,
438    activation: &ActivationDescriptor,
439    alpha: f32,
440    x_desc: &TensorDescriptor,
441    x: &DeviceBuffer<T>,
442    beta: f32,
443    y_desc: &TensorDescriptor,
444    y: &mut DeviceBuffer<T>,
445) -> Result<()> {
446    let c = cudnn()?;
447    let cu = c.cudnn_activation_forward()?;
448    check(unsafe {
449        cu(
450            handle.handle,
451            activation.desc,
452            &alpha as *const f32 as *const core::ffi::c_void,
453            x_desc.desc,
454            x.as_raw().0 as *const core::ffi::c_void,
455            &beta as *const f32 as *const core::ffi::c_void,
456            y_desc.desc,
457            y.as_raw().0 as *mut core::ffi::c_void,
458        )
459    })
460}
461
462// ---- convolution ---------------------------------------------------------
463
464/// `N × C × H × W` 4-D filter.
465pub struct FilterDescriptor {
466    desc: cudnnFilterDescriptor_t,
467}
468
469unsafe impl Send for FilterDescriptor {}
470
471impl core::fmt::Debug for FilterDescriptor {
472    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
473        f.debug_struct("FilterDescriptor")
474            .field("desc", &self.desc)
475            .finish_non_exhaustive()
476    }
477}
478
479impl FilterDescriptor {
480    /// Describe a 4-D filter (convolution weight): K output channels, C input
481    /// channels, H filter height, W filter width.
482    pub fn new_4d(
483        format: TensorFormat,
484        dtype: DType,
485        k: i32,
486        c: i32,
487        h: i32,
488        w: i32,
489    ) -> Result<Self> {
490        let cu = cudnn()?;
491        let create = cu.cudnn_create_filter_descriptor()?;
492        let set = cu.cudnn_set_filter_4d_descriptor()?;
493        let mut desc: cudnnFilterDescriptor_t = core::ptr::null_mut();
494        check(unsafe { create(&mut desc) })?;
495        let this = Self { desc };
496        check(unsafe { set(this.desc, dtype.raw(), format.raw(), k, c, h, w) })?;
497        Ok(this)
498    }
499
500    /// Raw descriptor.
501    #[inline]
502    pub fn as_raw(&self) -> cudnnFilterDescriptor_t {
503        self.desc
504    }
505}
506
507impl Drop for FilterDescriptor {
508    fn drop(&mut self) {
509        if let Ok(c) = cudnn() {
510            if let Ok(cu) = c.cudnn_destroy_filter_descriptor() {
511                let _ = unsafe { cu(self.desc) };
512            }
513        }
514    }
515}
516
517/// Convolution mathematical mode.
518#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
519pub enum ConvMode {
520    /// True convolution (flips the filter).
521    Convolution,
522    /// Cross-correlation — what ML frameworks mean by "convolution". **Default.**
523    #[default]
524    CrossCorrelation,
525}
526
527impl ConvMode {
528    fn raw(self) -> cudnnConvolutionMode_t {
529        match self {
530            ConvMode::Convolution => cudnnConvolutionMode_t::Convolution,
531            ConvMode::CrossCorrelation => cudnnConvolutionMode_t::CrossCorrelation,
532        }
533    }
534}
535
536/// Forward-convolution algorithm selector. `Gemm` is the most broadly
537/// supported; `ImplicitPrecompGemm` / `Winograd` are faster where
538/// applicable.
539#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
540pub enum FwdAlgo {
541    /// Implicit GEMM — no extra workspace, broad compatibility. **Default.**
542    #[default]
543    ImplicitGemm,
544    /// Implicit GEMM with pre-computed indexing tables for faster lookup.
545    ImplicitPrecompGemm,
546    /// Explicit im2col + GEMM; widest dtype / shape coverage.
547    Gemm,
548    /// Direct convolution; small kernel + low-batch sweet spot.
549    Direct,
550    /// FFT-based convolution; favors large kernels.
551    Fft,
552    /// Tiled FFT; lower workspace than `Fft` at some perf cost.
553    FftTiling,
554    /// Winograd small-kernel fast algorithm.
555    Winograd,
556    /// Non-fused Winograd; trades workspace for perf.
557    WinogradNonfused,
558}
559
560impl FwdAlgo {
561    fn raw(self) -> cudnnConvolutionFwdAlgo_t {
562        match self {
563            FwdAlgo::ImplicitGemm => cudnnConvolutionFwdAlgo_t::ImplicitGemm,
564            FwdAlgo::ImplicitPrecompGemm => cudnnConvolutionFwdAlgo_t::ImplicitPrecompGemm,
565            FwdAlgo::Gemm => cudnnConvolutionFwdAlgo_t::Gemm,
566            FwdAlgo::Direct => cudnnConvolutionFwdAlgo_t::Direct,
567            FwdAlgo::Fft => cudnnConvolutionFwdAlgo_t::Fft,
568            FwdAlgo::FftTiling => cudnnConvolutionFwdAlgo_t::FftTiling,
569            FwdAlgo::Winograd => cudnnConvolutionFwdAlgo_t::Winograd,
570            FwdAlgo::WinogradNonfused => cudnnConvolutionFwdAlgo_t::WinogradNonfused,
571        }
572    }
573}
574
575/// Convolution descriptor: padding, stride, dilation, and compute dtype.
576pub struct ConvolutionDescriptor {
577    desc: cudnnConvolutionDescriptor_t,
578}
579
580unsafe impl Send for ConvolutionDescriptor {}
581
582impl core::fmt::Debug for ConvolutionDescriptor {
583    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
584        f.debug_struct("ConvolutionDescriptor")
585            .field("desc", &self.desc)
586            .finish_non_exhaustive()
587    }
588}
589
590impl ConvolutionDescriptor {
591    /// 2-D convolution descriptor.
592    /// - `pad_h`/`pad_w`: zero-padding on each side of the H/W axis.
593    /// - `stride_h`/`stride_w`: per-axis stride.
594    /// - `dilation_h`/`dilation_w`: per-axis dilation (1 = standard).
595    /// - `mode`: [`ConvMode::CrossCorrelation`] for ML; [`ConvMode::Convolution`] for true math convolution.
596    /// - `compute`: accumulation dtype (pass [`DType::F32`] even for mixed-precision FP16 input/output).
597    #[allow(clippy::too_many_arguments)]
598    pub fn new_2d(
599        pad_h: i32,
600        pad_w: i32,
601        stride_h: i32,
602        stride_w: i32,
603        dilation_h: i32,
604        dilation_w: i32,
605        mode: ConvMode,
606        compute: DType,
607    ) -> Result<Self> {
608        let cu = cudnn()?;
609        let create = cu.cudnn_create_convolution_descriptor()?;
610        let set = cu.cudnn_set_convolution_2d_descriptor()?;
611        let mut desc: cudnnConvolutionDescriptor_t = core::ptr::null_mut();
612        check(unsafe { create(&mut desc) })?;
613        let this = Self { desc };
614        check(unsafe {
615            set(
616                this.desc,
617                pad_h,
618                pad_w,
619                stride_h,
620                stride_w,
621                dilation_h,
622                dilation_w,
623                mode.raw(),
624                compute.raw(),
625            )
626        })?;
627        Ok(this)
628    }
629
630    /// Compute the `N × C × H × W` shape this convolution would produce given
631    /// the input + filter descriptors.
632    pub fn output_dim_2d(
633        &self,
634        input: &TensorDescriptor,
635        filter: &FilterDescriptor,
636    ) -> Result<(i32, i32, i32, i32)> {
637        let cu = cudnn()?;
638        let q = cu.cudnn_get_convolution_2d_forward_output_dim()?;
639        let mut n: core::ffi::c_int = 0;
640        let mut c: core::ffi::c_int = 0;
641        let mut h: core::ffi::c_int = 0;
642        let mut w: core::ffi::c_int = 0;
643        check(unsafe {
644            q(
645                self.desc,
646                input.desc,
647                filter.desc,
648                &mut n,
649                &mut c,
650                &mut h,
651                &mut w,
652            )
653        })?;
654        Ok((n, c, h, w))
655    }
656
657    /// Set the group count for grouped convolution. The default is 1
658    /// (regular convolution); pass `g > 1` for depthwise / grouped
659    /// variants. Filter shape must match: input C divides g, filter C
660    /// = input C / g.
661    pub fn set_group_count(&self, group_count: i32) -> Result<()> {
662        let cu = cudnn()?;
663        let f = cu.cudnn_set_convolution_group_count()?;
664        check(unsafe { f(self.desc, group_count) })
665    }
666
667    /// Read back the convolution group count.
668    pub fn group_count(&self) -> Result<i32> {
669        let cu = cudnn()?;
670        let f = cu.cudnn_get_convolution_group_count()?;
671        let mut g: core::ffi::c_int = 0;
672        check(unsafe { f(self.desc, &mut g) })?;
673        Ok(g)
674    }
675
676    /// Pick the math type cuDNN uses for this convolution — controls
677    /// whether tensor cores are eligible.
678    pub fn set_math_type(&self, math: MathType) -> Result<()> {
679        let cu = cudnn()?;
680        let f = cu.cudnn_set_convolution_math_type()?;
681        check(unsafe { f(self.desc, math.raw()) })
682    }
683
684    /// Read back the convolution math type.
685    pub fn math_type(&self) -> Result<MathType> {
686        let cu = cudnn()?;
687        let f = cu.cudnn_get_convolution_math_type()?;
688        let mut m = cudnnMathType_t::DefaultMath;
689        check(unsafe { f(self.desc, &mut m) })?;
690        Ok(MathType::from_raw(m))
691    }
692
693    /// Set the filter / bias reorder type for INT8 quantized inference.
694    pub fn set_reorder_type(&self, reorder: ReorderType) -> Result<()> {
695        let cu = cudnn()?;
696        let f = cu.cudnn_set_convolution_reorder_type()?;
697        check(unsafe { f(self.desc, reorder.raw()) })
698    }
699
700    /// Read back the reorder type.
701    pub fn reorder_type(&self) -> Result<ReorderType> {
702        let cu = cudnn()?;
703        let f = cu.cudnn_get_convolution_reorder_type()?;
704        let mut r = cudnnReorderType_t::DefaultReorder;
705        check(unsafe { f(self.desc, &mut r) })?;
706        Ok(ReorderType::from_raw(r))
707    }
708
709    /// Raw descriptor.
710    #[inline]
711    pub fn as_raw(&self) -> cudnnConvolutionDescriptor_t {
712        self.desc
713    }
714}
715
716impl Drop for ConvolutionDescriptor {
717    fn drop(&mut self) {
718        if let Ok(c) = cudnn() {
719            if let Ok(cu) = c.cudnn_destroy_convolution_descriptor() {
720                let _ = unsafe { cu(self.desc) };
721            }
722        }
723    }
724}
725
726/// Query the minimum workspace (bytes) required to run `algo` with the given
727/// tensor / filter / conv descriptors.
728pub fn convolution_forward_workspace_size(
729    handle: &Handle,
730    x: &TensorDescriptor,
731    w: &FilterDescriptor,
732    conv: &ConvolutionDescriptor,
733    y: &TensorDescriptor,
734    algo: FwdAlgo,
735) -> Result<usize> {
736    let cu = cudnn()?;
737    let q = cu.cudnn_get_convolution_forward_workspace_size()?;
738    let mut size: usize = 0;
739    check(unsafe {
740        q(
741            handle.handle,
742            x.desc,
743            w.desc,
744            conv.desc,
745            y.desc,
746            algo.raw(),
747            &mut size,
748        )
749    })?;
750    Ok(size)
751}
752
753/// `Y = alpha * conv(X, W) + beta * Y` (forward pass).
754///
755/// `workspace` must be at least the size returned by
756/// [`convolution_forward_workspace_size`].
757///
758/// # Example
759///
760/// End-to-end "build descriptors → query workspace → run forward".
761/// The example uses `f32` and a 3×3 convolution with padding 1,
762/// stride 1, dilation 1, no groups.
763///
764/// ```no_run
765/// use baracuda_cudnn::{
766///     convolution_forward, convolution_forward_workspace_size,
767///     ConvMode, ConvolutionDescriptor, DType, FilterDescriptor, FwdAlgo,
768///     Handle, TensorDescriptor, TensorFormat,
769/// };
770/// use baracuda_driver::{Context, Device, DeviceBuffer};
771///
772/// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
773/// let ctx   = Context::new(&Device::get(0)?)?;
774/// let cudnn = Handle::new()?;
775///
776/// // Shapes: NCHW 1×3×32×32 input, 16 output channels, 3×3 kernel, pad 1.
777/// let (n, c, h, w)   = (1, 3, 32, 32);
778/// let (k, kh, kw)    = (16, 3, 3);
779/// let (pad_h, pad_w) = (1, 1);
780/// let (str_h, str_w) = (1, 1);
781/// let (dil_h, dil_w) = (1, 1);
782/// let (out_h, out_w) = (h, w);   // same-size output for pad=1, k=3, str=1
783///
784/// // Note the argument order: TensorDescriptor::new_4d takes
785/// // (format, dtype, n, c, h, w); FilterDescriptor::new_4d takes
786/// // (format, dtype, k, c, kh, kw).
787/// let x_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, n, c, h, w)?;
788/// let w_desc = FilterDescriptor::new_4d(TensorFormat::Nchw, DType::F32, k, c, kh, kw)?;
789/// let y_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, n, k, out_h, out_w)?;
790/// let conv = ConvolutionDescriptor::new_2d(
791///     pad_h, pad_w, str_h, str_w, dil_h, dil_w,
792///     ConvMode::CrossCorrelation, DType::F32,
793/// )?;
794/// // For grouped conv, set the group count after creation:
795/// // conv.set_group_count(groups)?;
796///
797/// // Pick an algorithm. ImplicitGemm is a safe default; for perf, use
798/// // `find_convolution_forward_algorithm` to benchmark on your shapes.
799/// let algo = FwdAlgo::ImplicitGemm;
800///
801/// // Workspace size depends on (descs, algo).
802/// let ws_bytes = convolution_forward_workspace_size(
803///     &cudnn, &x_desc, &w_desc, &conv, &y_desc, algo,
804/// )?;
805///
806/// // Allocate input / weight / output / workspace on the device.
807/// let x_buf:   DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (n*c*h*w) as usize)?;
808/// let w_buf:   DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (k*c*kh*kw) as usize)?;
809/// let mut y_buf: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (n*k*out_h*out_w) as usize)?;
810/// let mut ws: DeviceBuffer<u8> = DeviceBuffer::zeros(&ctx, ws_bytes.max(1))?;
811///
812/// convolution_forward(
813///     &cudnn,
814///     1.0, &x_desc, &x_buf,
815///          &w_desc, &w_buf,
816///     &conv, algo,
817///     &mut ws,
818///     0.0, &y_desc, &mut y_buf,
819/// )?;
820/// # Ok(()) }
821/// ```
822#[allow(clippy::too_many_arguments)]
823pub fn convolution_forward<T: DeviceRepr>(
824    handle: &Handle,
825    alpha: f32,
826    x_desc: &TensorDescriptor,
827    x: &DeviceBuffer<T>,
828    w_desc: &FilterDescriptor,
829    w: &DeviceBuffer<T>,
830    conv: &ConvolutionDescriptor,
831    algo: FwdAlgo,
832    workspace: &mut DeviceBuffer<u8>,
833    beta: f32,
834    y_desc: &TensorDescriptor,
835    y: &mut DeviceBuffer<T>,
836) -> Result<()> {
837    let c = cudnn()?;
838    let cu = c.cudnn_convolution_forward()?;
839    check(unsafe {
840        cu(
841            handle.handle,
842            &alpha as *const f32 as *const core::ffi::c_void,
843            x_desc.desc,
844            x.as_raw().0 as *const core::ffi::c_void,
845            w_desc.desc,
846            w.as_raw().0 as *const core::ffi::c_void,
847            conv.desc,
848            algo.raw(),
849            workspace.as_raw().0 as *mut core::ffi::c_void,
850            workspace.byte_size(),
851            &beta as *const f32 as *const core::ffi::c_void,
852            y_desc.desc,
853            y.as_raw().0 as *mut core::ffi::c_void,
854        )
855    })
856}
857
858// ---- convolution backward ------------------------------------------------
859
860/// Backward-data convolution algorithm selector.
861#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
862pub enum BwdDataAlgo {
863    /// Non-deterministic, broad-coverage backward-data algorithm. **Default.**
864    #[default]
865    Algo0,
866    /// Deterministic backward-data algorithm.
867    Algo1,
868    /// FFT-based backward-data algorithm.
869    Fft,
870    /// Tiled FFT — lower workspace than `Fft`.
871    FftTiling,
872    /// Winograd small-kernel fast algorithm.
873    Winograd,
874    /// Non-fused Winograd; trades workspace for perf.
875    WinogradNonfused,
876}
877
878impl BwdDataAlgo {
879    fn raw(self) -> cudnnConvolutionBwdDataAlgo_t {
880        match self {
881            Self::Algo0 => cudnnConvolutionBwdDataAlgo_t::Algo0,
882            Self::Algo1 => cudnnConvolutionBwdDataAlgo_t::Algo1,
883            Self::Fft => cudnnConvolutionBwdDataAlgo_t::Fft,
884            Self::FftTiling => cudnnConvolutionBwdDataAlgo_t::FftTiling,
885            Self::Winograd => cudnnConvolutionBwdDataAlgo_t::Winograd,
886            Self::WinogradNonfused => cudnnConvolutionBwdDataAlgo_t::WinogradNonfused,
887        }
888    }
889}
890
/// Algorithm choices accepted by [`convolution_backward_filter`] and its
/// workspace-size query.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum BwdFilterAlgo {
    /// Broad-coverage algorithm whose results are non-deterministic.
    /// This is the [`Default`] variant.
    #[default]
    Algo0,
    /// Algorithm with deterministic results.
    Algo1,
    /// Backward-filter pass computed through an FFT.
    Fft,
    /// Another non-deterministic algorithm, typically quicker on large filters.
    Algo3,
    /// Winograd fast algorithm for small kernels.
    Winograd,
    /// Winograd without fusion; spends workspace to gain performance.
    WinogradNonfused,
    /// FFT evaluated in tiles; needs less workspace than `Fft`.
    FftTiling,
}
910
911impl BwdFilterAlgo {
912    fn raw(self) -> cudnnConvolutionBwdFilterAlgo_t {
913        match self {
914            Self::Algo0 => cudnnConvolutionBwdFilterAlgo_t::Algo0,
915            Self::Algo1 => cudnnConvolutionBwdFilterAlgo_t::Algo1,
916            Self::Fft => cudnnConvolutionBwdFilterAlgo_t::Fft,
917            Self::Algo3 => cudnnConvolutionBwdFilterAlgo_t::Algo3,
918            Self::Winograd => cudnnConvolutionBwdFilterAlgo_t::Winograd,
919            Self::WinogradNonfused => cudnnConvolutionBwdFilterAlgo_t::WinogradNonfused,
920            Self::FftTiling => cudnnConvolutionBwdFilterAlgo_t::FftTiling,
921        }
922    }
923}
924
925/// Workspace bytes required to run [`convolution_backward_data`] with the
926/// given `algo` and descriptors.
927pub fn convolution_backward_data_workspace_size(
928    handle: &Handle,
929    w: &FilterDescriptor,
930    dy: &TensorDescriptor,
931    conv: &ConvolutionDescriptor,
932    dx: &TensorDescriptor,
933    algo: BwdDataAlgo,
934) -> Result<usize> {
935    let cu = cudnn()?;
936    let q = cu.cudnn_get_convolution_backward_data_workspace_size()?;
937    let mut size = 0usize;
938    check(unsafe {
939        q(
940            handle.handle,
941            w.desc,
942            dy.desc,
943            conv.desc,
944            dx.desc,
945            algo.raw(),
946            &mut size,
947        )
948    })?;
949    Ok(size)
950}
951
952/// Workspace bytes required to run [`convolution_backward_filter`] with
953/// the given `algo` and descriptors.
954pub fn convolution_backward_filter_workspace_size(
955    handle: &Handle,
956    x: &TensorDescriptor,
957    dy: &TensorDescriptor,
958    conv: &ConvolutionDescriptor,
959    dw: &FilterDescriptor,
960    algo: BwdFilterAlgo,
961) -> Result<usize> {
962    let cu = cudnn()?;
963    let q = cu.cudnn_get_convolution_backward_filter_workspace_size()?;
964    let mut size = 0usize;
965    check(unsafe {
966        q(
967            handle.handle,
968            x.desc,
969            dy.desc,
970            conv.desc,
971            dw.desc,
972            algo.raw(),
973            &mut size,
974        )
975    })?;
976    Ok(size)
977}
978
979/// `dX = alpha * conv_bwd_data(W, dY) + beta * dX`.
980#[allow(clippy::too_many_arguments)]
981pub fn convolution_backward_data<T: DeviceRepr>(
982    handle: &Handle,
983    alpha: f32,
984    w_desc: &FilterDescriptor,
985    w: &DeviceBuffer<T>,
986    dy_desc: &TensorDescriptor,
987    dy: &DeviceBuffer<T>,
988    conv: &ConvolutionDescriptor,
989    algo: BwdDataAlgo,
990    workspace: &mut DeviceBuffer<u8>,
991    beta: f32,
992    dx_desc: &TensorDescriptor,
993    dx: &mut DeviceBuffer<T>,
994) -> Result<()> {
995    let c = cudnn()?;
996    let cu = c.cudnn_convolution_backward_data()?;
997    check(unsafe {
998        cu(
999            handle.handle,
1000            &alpha as *const f32 as *const core::ffi::c_void,
1001            w_desc.desc,
1002            w.as_raw().0 as *const core::ffi::c_void,
1003            dy_desc.desc,
1004            dy.as_raw().0 as *const core::ffi::c_void,
1005            conv.desc,
1006            algo.raw(),
1007            workspace.as_raw().0 as *mut core::ffi::c_void,
1008            workspace.byte_size(),
1009            &beta as *const f32 as *const core::ffi::c_void,
1010            dx_desc.desc,
1011            dx.as_raw().0 as *mut core::ffi::c_void,
1012        )
1013    })
1014}
1015
1016/// `dW = alpha * conv_bwd_filter(X, dY) + beta * dW`.
1017#[allow(clippy::too_many_arguments)]
1018pub fn convolution_backward_filter<T: DeviceRepr>(
1019    handle: &Handle,
1020    alpha: f32,
1021    x_desc: &TensorDescriptor,
1022    x: &DeviceBuffer<T>,
1023    dy_desc: &TensorDescriptor,
1024    dy: &DeviceBuffer<T>,
1025    conv: &ConvolutionDescriptor,
1026    algo: BwdFilterAlgo,
1027    workspace: &mut DeviceBuffer<u8>,
1028    beta: f32,
1029    dw_desc: &FilterDescriptor,
1030    dw: &mut DeviceBuffer<T>,
1031) -> Result<()> {
1032    let c = cudnn()?;
1033    let cu = c.cudnn_convolution_backward_filter()?;
1034    check(unsafe {
1035        cu(
1036            handle.handle,
1037            &alpha as *const f32 as *const core::ffi::c_void,
1038            x_desc.desc,
1039            x.as_raw().0 as *const core::ffi::c_void,
1040            dy_desc.desc,
1041            dy.as_raw().0 as *const core::ffi::c_void,
1042            conv.desc,
1043            algo.raw(),
1044            workspace.as_raw().0 as *mut core::ffi::c_void,
1045            workspace.byte_size(),
1046            &beta as *const f32 as *const core::ffi::c_void,
1047            dw_desc.desc,
1048            dw.as_raw().0 as *mut core::ffi::c_void,
1049        )
1050    })
1051}
1052
1053/// Add the bias gradient: sum over spatial dims of `dY`.
1054pub fn convolution_backward_bias<T: DeviceRepr>(
1055    handle: &Handle,
1056    alpha: f32,
1057    dy_desc: &TensorDescriptor,
1058    dy: &DeviceBuffer<T>,
1059    beta: f32,
1060    db_desc: &TensorDescriptor,
1061    db: &mut DeviceBuffer<T>,
1062) -> Result<()> {
1063    let c = cudnn()?;
1064    let cu = c.cudnn_convolution_backward_bias()?;
1065    check(unsafe {
1066        cu(
1067            handle.handle,
1068            &alpha as *const f32 as *const core::ffi::c_void,
1069            dy_desc.desc,
1070            dy.as_raw().0 as *const core::ffi::c_void,
1071            &beta as *const f32 as *const core::ffi::c_void,
1072            db_desc.desc,
1073            db.as_raw().0 as *mut core::ffi::c_void,
1074        )
1075    })
1076}
1077
1078// ---- pooling --------------------------------------------------------------
1079
/// Which reduction a pooling window applies.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PoolingMode {
    /// Maximum over the window; non-deterministic when ties occur.
    /// This is the [`Default`] variant.
    #[default]
    Max,
    /// Mean over the window, with padded cells counted in the denominator.
    AverageCountIncludePadding,
    /// Mean over the window, with padded cells left out of the denominator.
    AverageCountExcludePadding,
    /// Maximum over the window with a deterministic tie-break (lower throughput).
    MaxDeterministic,
}
1093
1094impl PoolingMode {
1095    fn raw(self) -> cudnnPoolingMode_t {
1096        match self {
1097            Self::Max => cudnnPoolingMode_t::Max,
1098            Self::AverageCountIncludePadding => cudnnPoolingMode_t::AverageCountIncludePadding,
1099            Self::AverageCountExcludePadding => cudnnPoolingMode_t::AverageCountExcludePadding,
1100            Self::MaxDeterministic => cudnnPoolingMode_t::MaxDeterministic,
1101        }
1102    }
1103}
1104
1105/// A pooling descriptor: pooling mode, window extent, padding, and stride.
1106pub struct PoolingDescriptor {
1107    desc: cudnnPoolingDescriptor_t,
1108}
1109
1110unsafe impl Send for PoolingDescriptor {}
1111
1112impl core::fmt::Debug for PoolingDescriptor {
1113    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1114        f.debug_struct("PoolingDescriptor")
1115            .field("desc", &self.desc)
1116            .finish_non_exhaustive()
1117    }
1118}
1119
1120impl PoolingDescriptor {
1121    /// 2-D pooling descriptor with explicit window / padding / stride.
1122    /// NaN propagation defaults to `PropagateNan`.
1123    #[allow(clippy::too_many_arguments)]
1124    pub fn new_2d(
1125        mode: PoolingMode,
1126        window_h: i32,
1127        window_w: i32,
1128        pad_h: i32,
1129        pad_w: i32,
1130        stride_h: i32,
1131        stride_w: i32,
1132    ) -> Result<Self> {
1133        let cu = cudnn()?;
1134        let create = cu.cudnn_create_pooling_descriptor()?;
1135        let set = cu.cudnn_set_pooling_2d_descriptor()?;
1136        let mut desc: cudnnPoolingDescriptor_t = core::ptr::null_mut();
1137        check(unsafe { create(&mut desc) })?;
1138        let this = Self { desc };
1139        check(unsafe {
1140            set(
1141                this.desc,
1142                mode.raw(),
1143                cudnnNanPropagation_t::PropagateNan,
1144                window_h,
1145                window_w,
1146                pad_h,
1147                pad_w,
1148                stride_h,
1149                stride_w,
1150            )
1151        })?;
1152        Ok(this)
1153    }
1154
1155    /// Raw descriptor.
1156    #[inline]
1157    pub fn as_raw(&self) -> cudnnPoolingDescriptor_t {
1158        self.desc
1159    }
1160}
1161
1162impl Drop for PoolingDescriptor {
1163    fn drop(&mut self) {
1164        if let Ok(c) = cudnn() {
1165            if let Ok(cu) = c.cudnn_destroy_pooling_descriptor() {
1166                let _ = unsafe { cu(self.desc) };
1167            }
1168        }
1169    }
1170}
1171
1172/// `Y = alpha * pool(X) + beta * Y` (forward pass).
1173///
1174/// # Example
1175///
1176/// 2×2 max-pool with stride 2 on a `1 × 16 × 8 × 8` input → `1 × 16 × 4 × 4`
1177/// output.
1178///
1179/// ```no_run
1180/// use baracuda_driver::{Context, Device, DeviceBuffer};
1181/// use baracuda_cudnn::{
1182///     pooling_forward, DType, Handle, PoolingDescriptor, PoolingMode,
1183///     TensorDescriptor, TensorFormat,
1184/// };
1185///
1186/// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
1187/// let ctx = Context::new(&Device::get(0)?)?;
1188/// let cudnn = Handle::new()?;
1189///
1190/// let (n, c) = (1, 16);
1191/// let (in_h, in_w) = (8, 8);
1192/// let (out_h, out_w) = (4, 4);
1193///
1194/// let x_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, n, c, in_h, in_w)?;
1195/// let y_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, n, c, out_h, out_w)?;
1196/// // MaxPool2d: window=2x2, pad=0, stride=2.
1197/// let pool = PoolingDescriptor::new_2d(PoolingMode::Max, 2, 2, 0, 0, 2, 2)?;
1198///
1199/// let x: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (n*c*in_h*in_w) as usize)?;
1200/// let mut y: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (n*c*out_h*out_w) as usize)?;
1201///
1202/// pooling_forward(&cudnn, &pool, 1.0, &x_desc, &x, 0.0, &y_desc, &mut y)?;
1203/// # Ok(()) }
1204/// ```
1205#[allow(clippy::too_many_arguments)]
1206pub fn pooling_forward<T: DeviceRepr>(
1207    handle: &Handle,
1208    pool: &PoolingDescriptor,
1209    alpha: f32,
1210    x_desc: &TensorDescriptor,
1211    x: &DeviceBuffer<T>,
1212    beta: f32,
1213    y_desc: &TensorDescriptor,
1214    y: &mut DeviceBuffer<T>,
1215) -> Result<()> {
1216    let c = cudnn()?;
1217    let cu = c.cudnn_pooling_forward()?;
1218    check(unsafe {
1219        cu(
1220            handle.handle,
1221            pool.desc,
1222            &alpha as *const f32 as *const core::ffi::c_void,
1223            x_desc.desc,
1224            x.as_raw().0 as *const core::ffi::c_void,
1225            &beta as *const f32 as *const core::ffi::c_void,
1226            y_desc.desc,
1227            y.as_raw().0 as *mut core::ffi::c_void,
1228        )
1229    })
1230}
1231
1232/// `dX = alpha * pool_backward(Y, dY, X) + beta * dX`.
1233#[allow(clippy::too_many_arguments)]
1234pub fn pooling_backward<T: DeviceRepr>(
1235    handle: &Handle,
1236    pool: &PoolingDescriptor,
1237    alpha: f32,
1238    y_desc: &TensorDescriptor,
1239    y: &DeviceBuffer<T>,
1240    dy_desc: &TensorDescriptor,
1241    dy: &DeviceBuffer<T>,
1242    x_desc: &TensorDescriptor,
1243    x: &DeviceBuffer<T>,
1244    beta: f32,
1245    dx_desc: &TensorDescriptor,
1246    dx: &mut DeviceBuffer<T>,
1247) -> Result<()> {
1248    let c = cudnn()?;
1249    let cu = c.cudnn_pooling_backward()?;
1250    check(unsafe {
1251        cu(
1252            handle.handle,
1253            pool.desc,
1254            &alpha as *const f32 as *const core::ffi::c_void,
1255            y_desc.desc,
1256            y.as_raw().0 as *const core::ffi::c_void,
1257            dy_desc.desc,
1258            dy.as_raw().0 as *const core::ffi::c_void,
1259            x_desc.desc,
1260            x.as_raw().0 as *const core::ffi::c_void,
1261            &beta as *const f32 as *const core::ffi::c_void,
1262            dx_desc.desc,
1263            dx.as_raw().0 as *mut core::ffi::c_void,
1264        )
1265    })
1266}
1267
1268// ---- softmax --------------------------------------------------------------
1269
/// How the softmax is evaluated numerically.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum SoftmaxAlgo {
    /// Plain `exp(x) / sum(exp(x))`; quick, but can overflow.
    Fast,
    /// Softmax with the maximum subtracted first for numerical stability.
    /// This is the [`Default`] variant.
    #[default]
    Accurate,
    /// Log-softmax, i.e. `x - logsumexp(x)`.
    Log,
}
1281
1282impl SoftmaxAlgo {
1283    fn raw(self) -> cudnnSoftmaxAlgorithm_t {
1284        match self {
1285            Self::Fast => cudnnSoftmaxAlgorithm_t::Fast,
1286            Self::Accurate => cudnnSoftmaxAlgorithm_t::Accurate,
1287            Self::Log => cudnnSoftmaxAlgorithm_t::Log,
1288        }
1289    }
1290}
1291
/// Which elements share one softmax normalizer.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum SoftmaxMode {
    /// Normalize across the whole `C × H × W` block of each sample, giving one
    /// normalizer per batch row.
    Instance,
    /// Normalize across `C` at each spatial location, giving one normalizer
    /// per `(N, H, W)`. This is the [`Default`] variant.
    #[default]
    Channel,
}
1301
1302impl SoftmaxMode {
1303    fn raw(self) -> cudnnSoftmaxMode_t {
1304        match self {
1305            Self::Instance => cudnnSoftmaxMode_t::Instance,
1306            Self::Channel => cudnnSoftmaxMode_t::Channel,
1307        }
1308    }
1309}
1310
1311/// `Y = alpha * softmax(X, algo, mode) + beta * Y`.
1312#[allow(clippy::too_many_arguments)]
1313pub fn softmax_forward<T: DeviceRepr>(
1314    handle: &Handle,
1315    algo: SoftmaxAlgo,
1316    mode: SoftmaxMode,
1317    alpha: f32,
1318    x_desc: &TensorDescriptor,
1319    x: &DeviceBuffer<T>,
1320    beta: f32,
1321    y_desc: &TensorDescriptor,
1322    y: &mut DeviceBuffer<T>,
1323) -> Result<()> {
1324    let c = cudnn()?;
1325    let cu = c.cudnn_softmax_forward()?;
1326    check(unsafe {
1327        cu(
1328            handle.handle,
1329            algo.raw(),
1330            mode.raw(),
1331            &alpha as *const f32 as *const core::ffi::c_void,
1332            x_desc.desc,
1333            x.as_raw().0 as *const core::ffi::c_void,
1334            &beta as *const f32 as *const core::ffi::c_void,
1335            y_desc.desc,
1336            y.as_raw().0 as *mut core::ffi::c_void,
1337        )
1338    })
1339}
1340
1341/// `dX = alpha * softmax_backward(Y, dY) + beta * dX`.
1342#[allow(clippy::too_many_arguments)]
1343pub fn softmax_backward<T: DeviceRepr>(
1344    handle: &Handle,
1345    algo: SoftmaxAlgo,
1346    mode: SoftmaxMode,
1347    alpha: f32,
1348    y_desc: &TensorDescriptor,
1349    y: &DeviceBuffer<T>,
1350    dy_desc: &TensorDescriptor,
1351    dy: &DeviceBuffer<T>,
1352    beta: f32,
1353    dx_desc: &TensorDescriptor,
1354    dx: &mut DeviceBuffer<T>,
1355) -> Result<()> {
1356    let c = cudnn()?;
1357    let cu = c.cudnn_softmax_backward()?;
1358    check(unsafe {
1359        cu(
1360            handle.handle,
1361            algo.raw(),
1362            mode.raw(),
1363            &alpha as *const f32 as *const core::ffi::c_void,
1364            y_desc.desc,
1365            y.as_raw().0 as *const core::ffi::c_void,
1366            dy_desc.desc,
1367            dy.as_raw().0 as *const core::ffi::c_void,
1368            &beta as *const f32 as *const core::ffi::c_void,
1369            dx_desc.desc,
1370            dx.as_raw().0 as *mut core::ffi::c_void,
1371        )
1372    })
1373}
1374
1375// ---- batch normalization --------------------------------------------------
1376
/// How batch-normalization scale / bias parameters are shared across a tensor.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum BatchNormMode {
    /// A separate scale/bias for every `(C, H, W)` cell — the mode used by
    /// fully-connected BN.
    PerActivation,
    /// A single scale/bias per `C` channel, shared across spatial positions;
    /// matches `nn.BatchNorm2d`. This is the [`Default`] variant.
    #[default]
    Spatial,
    /// Like `Spatial`, but selects the persistent-mode fast kernel where available.
    SpatialPersistent,
}
1388
1389impl BatchNormMode {
1390    fn raw(self) -> cudnnBatchNormMode_t {
1391        match self {
1392            Self::PerActivation => cudnnBatchNormMode_t::PerActivation,
1393            Self::Spatial => cudnnBatchNormMode_t::Spatial,
1394            Self::SpatialPersistent => cudnnBatchNormMode_t::SpatialPersistent,
1395        }
1396    }
1397}
1398
1399/// Training-time BN forward: updates running statistics and returns saved
1400/// `mean` / `inv_variance` for use by [`batch_normalization_backward`].
1401#[allow(clippy::too_many_arguments)]
1402pub fn batch_normalization_forward_training<T: DeviceRepr>(
1403    handle: &Handle,
1404    mode: BatchNormMode,
1405    alpha: f32,
1406    beta: f32,
1407    x_desc: &TensorDescriptor,
1408    x: &DeviceBuffer<T>,
1409    y_desc: &TensorDescriptor,
1410    y: &mut DeviceBuffer<T>,
1411    bn_smbv_desc: &TensorDescriptor,
1412    bn_scale: &DeviceBuffer<T>,
1413    bn_bias: &DeviceBuffer<T>,
1414    exponential_avg_factor: f64,
1415    running_mean: &mut DeviceBuffer<T>,
1416    running_variance: &mut DeviceBuffer<T>,
1417    epsilon: f64,
1418    saved_mean: &mut DeviceBuffer<T>,
1419    saved_inv_variance: &mut DeviceBuffer<T>,
1420) -> Result<()> {
1421    let c = cudnn()?;
1422    let cu = c.cudnn_batch_normalization_forward_training()?;
1423    check(unsafe {
1424        cu(
1425            handle.handle,
1426            mode.raw(),
1427            &alpha as *const f32 as *const core::ffi::c_void,
1428            &beta as *const f32 as *const core::ffi::c_void,
1429            x_desc.desc,
1430            x.as_raw().0 as *const core::ffi::c_void,
1431            y_desc.desc,
1432            y.as_raw().0 as *mut core::ffi::c_void,
1433            bn_smbv_desc.desc,
1434            bn_scale.as_raw().0 as *const core::ffi::c_void,
1435            bn_bias.as_raw().0 as *const core::ffi::c_void,
1436            exponential_avg_factor,
1437            running_mean.as_raw().0 as *mut core::ffi::c_void,
1438            running_variance.as_raw().0 as *mut core::ffi::c_void,
1439            epsilon,
1440            saved_mean.as_raw().0 as *mut core::ffi::c_void,
1441            saved_inv_variance.as_raw().0 as *mut core::ffi::c_void,
1442        )
1443    })
1444}
1445
1446/// BN backward — matched with [`batch_normalization_forward_training`].
1447#[allow(clippy::too_many_arguments)]
1448pub fn batch_normalization_backward<T: DeviceRepr>(
1449    handle: &Handle,
1450    mode: BatchNormMode,
1451    alpha_data_diff: f32,
1452    beta_data_diff: f32,
1453    alpha_param_diff: f32,
1454    beta_param_diff: f32,
1455    x_desc: &TensorDescriptor,
1456    x: &DeviceBuffer<T>,
1457    dy_desc: &TensorDescriptor,
1458    dy: &DeviceBuffer<T>,
1459    dx_desc: &TensorDescriptor,
1460    dx: &mut DeviceBuffer<T>,
1461    bn_scale_bias_diff_desc: &TensorDescriptor,
1462    bn_scale: &DeviceBuffer<T>,
1463    d_bn_scale: &mut DeviceBuffer<T>,
1464    d_bn_bias: &mut DeviceBuffer<T>,
1465    epsilon: f64,
1466    saved_mean: &DeviceBuffer<T>,
1467    saved_inv_variance: &DeviceBuffer<T>,
1468) -> Result<()> {
1469    let c = cudnn()?;
1470    let cu = c.cudnn_batch_normalization_backward()?;
1471    check(unsafe {
1472        cu(
1473            handle.handle,
1474            mode.raw(),
1475            &alpha_data_diff as *const f32 as *const core::ffi::c_void,
1476            &beta_data_diff as *const f32 as *const core::ffi::c_void,
1477            &alpha_param_diff as *const f32 as *const core::ffi::c_void,
1478            &beta_param_diff as *const f32 as *const core::ffi::c_void,
1479            x_desc.desc,
1480            x.as_raw().0 as *const core::ffi::c_void,
1481            dy_desc.desc,
1482            dy.as_raw().0 as *const core::ffi::c_void,
1483            dx_desc.desc,
1484            dx.as_raw().0 as *mut core::ffi::c_void,
1485            bn_scale_bias_diff_desc.desc,
1486            bn_scale.as_raw().0 as *const core::ffi::c_void,
1487            d_bn_scale.as_raw().0 as *mut core::ffi::c_void,
1488            d_bn_bias.as_raw().0 as *mut core::ffi::c_void,
1489            epsilon,
1490            saved_mean.as_raw().0 as *const core::ffi::c_void,
1491            saved_inv_variance.as_raw().0 as *const core::ffi::c_void,
1492        )
1493    })
1494}
1495
1496/// Inference-time BN forward: uses pre-computed running statistics (no
1497/// state update). Use after model training is complete.
1498///
1499/// # Example
1500///
1501/// Per-spatial BN on a `1 × 16 × 8 × 8` NCHW tensor with 16 channel scale /
1502/// bias / mean / variance vectors.
1503///
1504/// ```no_run
1505/// use baracuda_driver::{Context, Device, DeviceBuffer};
1506/// use baracuda_cudnn::{
1507///     batch_normalization_forward_inference, BatchNormMode, DType, Handle,
1508///     TensorDescriptor, TensorFormat,
1509/// };
1510///
1511/// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
1512/// let ctx = Context::new(&Device::get(0)?)?;
1513/// let cudnn = Handle::new()?;
1514///
1515/// let (n, c, h, w) = (1, 16, 8, 8);
1516/// let xy_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, n, c, h, w)?;
1517/// // BN parameter tensor is shape (1, C, 1, 1) for Spatial mode.
1518/// let bn_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, 1, c, 1, 1)?;
1519///
1520/// let x: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (n*c*h*w) as usize)?;
1521/// let mut y: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (n*c*h*w) as usize)?;
1522/// let scale: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, c as usize)?;
1523/// let bias:  DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, c as usize)?;
1524/// let mean:  DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, c as usize)?;
1525/// let var:   DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, c as usize)?;
1526///
1527/// batch_normalization_forward_inference(
1528///     &cudnn, BatchNormMode::Spatial,
1529///     1.0, 0.0,
1530///     &xy_desc, &x, &xy_desc, &mut y,
1531///     &bn_desc, &scale, &bias, &mean, &var,
1532///     1e-5,
1533/// )?;
1534/// # Ok(()) }
1535/// ```
1536#[allow(clippy::too_many_arguments)]
1537pub fn batch_normalization_forward_inference<T: DeviceRepr>(
1538    handle: &Handle,
1539    mode: BatchNormMode,
1540    alpha: f32,
1541    beta: f32,
1542    x_desc: &TensorDescriptor,
1543    x: &DeviceBuffer<T>,
1544    y_desc: &TensorDescriptor,
1545    y: &mut DeviceBuffer<T>,
1546    bn_smbv_desc: &TensorDescriptor,
1547    bn_scale: &DeviceBuffer<T>,
1548    bn_bias: &DeviceBuffer<T>,
1549    estimated_mean: &DeviceBuffer<T>,
1550    estimated_var: &DeviceBuffer<T>,
1551    epsilon: f64,
1552) -> Result<()> {
1553    let c = cudnn()?;
1554    let cu = c.cudnn_batch_normalization_forward_inference()?;
1555    check(unsafe {
1556        cu(
1557            handle.handle,
1558            mode.raw(),
1559            &alpha as *const f32 as *const core::ffi::c_void,
1560            &beta as *const f32 as *const core::ffi::c_void,
1561            x_desc.desc,
1562            x.as_raw().0 as *const core::ffi::c_void,
1563            y_desc.desc,
1564            y.as_raw().0 as *mut core::ffi::c_void,
1565            bn_smbv_desc.desc,
1566            bn_scale.as_raw().0 as *const core::ffi::c_void,
1567            bn_bias.as_raw().0 as *const core::ffi::c_void,
1568            estimated_mean.as_raw().0 as *const core::ffi::c_void,
1569            estimated_var.as_raw().0 as *const core::ffi::c_void,
1570            epsilon,
1571        )
1572    })
1573}
1574
1575// ---- dropout --------------------------------------------------------------
1576
1577/// A dropout descriptor: dropout probability + RNG state buffer.
1578pub struct DropoutDescriptor {
1579    desc: cudnnDropoutDescriptor_t,
1580}
1581
1582unsafe impl Send for DropoutDescriptor {}
1583
1584impl core::fmt::Debug for DropoutDescriptor {
1585    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1586        f.debug_struct("DropoutDescriptor")
1587            .field("desc", &self.desc)
1588            .finish_non_exhaustive()
1589    }
1590}
1591
1592impl DropoutDescriptor {
1593    /// Create a dropout descriptor with probability `dropout` ∈ \[0, 1\].
1594    ///
1595    /// `states` is a driver-owned buffer of at least
1596    /// [`dropout_states_size`] bytes, shared across many descriptors.
1597    pub fn new(
1598        handle: &Handle,
1599        dropout: f32,
1600        states: &mut DeviceBuffer<u8>,
1601        seed: u64,
1602    ) -> Result<Self> {
1603        let cu = cudnn()?;
1604        let create = cu.cudnn_create_dropout_descriptor()?;
1605        let set = cu.cudnn_set_dropout_descriptor()?;
1606        let mut desc: cudnnDropoutDescriptor_t = core::ptr::null_mut();
1607        check(unsafe { create(&mut desc) })?;
1608        let this = Self { desc };
1609        check(unsafe {
1610            set(
1611                this.desc,
1612                handle.handle,
1613                dropout,
1614                states.as_raw().0 as *mut core::ffi::c_void,
1615                states.byte_size(),
1616                seed,
1617            )
1618        })?;
1619        Ok(this)
1620    }
1621
1622    /// Raw descriptor.
1623    #[inline]
1624    pub fn as_raw(&self) -> cudnnDropoutDescriptor_t {
1625        self.desc
1626    }
1627}
1628
1629impl Drop for DropoutDescriptor {
1630    fn drop(&mut self) {
1631        if let Ok(c) = cudnn() {
1632            if let Ok(cu) = c.cudnn_destroy_dropout_descriptor() {
1633                let _ = unsafe { cu(self.desc) };
1634            }
1635        }
1636    }
1637}
1638
1639/// Size in bytes of the state buffer required for a dropout RNG.
1640pub fn dropout_states_size(handle: &Handle) -> Result<usize> {
1641    let c = cudnn()?;
1642    let cu = c.cudnn_dropout_get_states_size()?;
1643    let mut size = 0usize;
1644    check(unsafe { cu(handle.handle, &mut size) })?;
1645    Ok(size)
1646}
1647
1648/// Size in bytes of the reserve buffer required for dropout on `x`.
1649pub fn dropout_reserve_size(x: &TensorDescriptor) -> Result<usize> {
1650    let c = cudnn()?;
1651    let cu = c.cudnn_dropout_get_reserve_space_size()?;
1652    let mut size = 0usize;
1653    check(unsafe { cu(x.desc, &mut size) })?;
1654    Ok(size)
1655}
1656
1657/// Apply dropout to `x`, writing scaled survivors to `y` and the
1658/// keep/drop mask into `reserve` for the matching backward call.
1659#[allow(clippy::too_many_arguments)]
1660pub fn dropout_forward<T: DeviceRepr>(
1661    handle: &Handle,
1662    dropout: &DropoutDescriptor,
1663    x_desc: &TensorDescriptor,
1664    x: &DeviceBuffer<T>,
1665    y_desc: &TensorDescriptor,
1666    y: &mut DeviceBuffer<T>,
1667    reserve: &mut DeviceBuffer<u8>,
1668) -> Result<()> {
1669    let c = cudnn()?;
1670    let cu = c.cudnn_dropout_forward()?;
1671    check(unsafe {
1672        cu(
1673            handle.handle,
1674            dropout.desc,
1675            x_desc.desc,
1676            x.as_raw().0 as *const core::ffi::c_void,
1677            y_desc.desc,
1678            y.as_raw().0 as *mut core::ffi::c_void,
1679            reserve.as_raw().0 as *mut core::ffi::c_void,
1680            reserve.byte_size(),
1681        )
1682    })
1683}
1684
1685/// Backward dropout: replays the mask saved in `reserve` to produce `dx`
1686/// from `dy`. `reserve` must be the exact buffer populated by the
1687/// matching [`dropout_forward`] call.
1688#[allow(clippy::too_many_arguments)]
1689pub fn dropout_backward<T: DeviceRepr>(
1690    handle: &Handle,
1691    dropout: &DropoutDescriptor,
1692    dy_desc: &TensorDescriptor,
1693    dy: &DeviceBuffer<T>,
1694    dx_desc: &TensorDescriptor,
1695    dx: &mut DeviceBuffer<T>,
1696    reserve: &mut DeviceBuffer<u8>,
1697) -> Result<()> {
1698    let c = cudnn()?;
1699    let cu = c.cudnn_dropout_backward()?;
1700    check(unsafe {
1701        cu(
1702            handle.handle,
1703            dropout.desc,
1704            dy_desc.desc,
1705            dy.as_raw().0 as *const core::ffi::c_void,
1706            dx_desc.desc,
1707            dx.as_raw().0 as *mut core::ffi::c_void,
1708            reserve.as_raw().0 as *mut core::ffi::c_void,
1709            reserve.byte_size(),
1710        )
1711    })
1712}
1713
1714// ---- LRN ------------------------------------------------------------------
1715
1716/// Local Response Normalization descriptor: window size + α / β / k coefficients.
1717pub struct LrnDescriptor {
1718    desc: cudnnLRNDescriptor_t,
1719}
1720
1721unsafe impl Send for LrnDescriptor {}
1722
1723impl core::fmt::Debug for LrnDescriptor {
1724    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1725        f.debug_struct("LrnDescriptor")
1726            .field("desc", &self.desc)
1727            .finish_non_exhaustive()
1728    }
1729}
1730
1731impl LrnDescriptor {
1732    /// Build an LRN descriptor with window size `n` and the standard
1733    /// `(α, β, k)` formula coefficients.
1734    pub fn new(n: i32, alpha: f64, beta: f64, k: f64) -> Result<Self> {
1735        let cu = cudnn()?;
1736        let create = cu.cudnn_create_lrn_descriptor()?;
1737        let set = cu.cudnn_set_lrn_descriptor()?;
1738        let mut desc: cudnnLRNDescriptor_t = core::ptr::null_mut();
1739        check(unsafe { create(&mut desc) })?;
1740        let this = Self { desc };
1741        check(unsafe { set(this.desc, n, alpha, beta, k) })?;
1742        Ok(this)
1743    }
1744
1745    /// Raw descriptor.
1746    #[inline]
1747    pub fn as_raw(&self) -> cudnnLRNDescriptor_t {
1748        self.desc
1749    }
1750}
1751
1752impl Drop for LrnDescriptor {
1753    fn drop(&mut self) {
1754        if let Ok(c) = cudnn() {
1755            if let Ok(cu) = c.cudnn_destroy_lrn_descriptor() {
1756                let _ = unsafe { cu(self.desc) };
1757            }
1758        }
1759    }
1760}
1761
1762// ---- op-tensor / reduce / transform --------------------------------------
1763
/// Element-wise operation selected through an [`OpTensorDescriptor`] and
/// applied by [`op_tensor`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OpTensorOp {
    /// Element-wise sum: `c = a + b`.
    Add,
    /// Element-wise product: `c = a * b`.
    Mul,
    /// Element-wise minimum: `c = min(a, b)`.
    Min,
    /// Element-wise maximum: `c = max(a, b)`.
    Max,
    /// Square root of `a`: `c = sqrt(a)`; `b` is ignored.
    Sqrt,
    /// `c = 1 - a`; `b` is ignored.
    Not,
}
1780
1781impl OpTensorOp {
1782    fn raw(self) -> cudnnOpTensorOp_t {
1783        match self {
1784            Self::Add => cudnnOpTensorOp_t::Add,
1785            Self::Mul => cudnnOpTensorOp_t::Mul,
1786            Self::Min => cudnnOpTensorOp_t::Min,
1787            Self::Max => cudnnOpTensorOp_t::Max,
1788            Self::Sqrt => cudnnOpTensorOp_t::Sqrt,
1789            Self::Not => cudnnOpTensorOp_t::Not,
1790        }
1791    }
1792}
1793
1794/// An op-tensor descriptor: binary element-wise op + compute dtype.
1795pub struct OpTensorDescriptor {
1796    desc: cudnnOpTensorDescriptor_t,
1797}
1798
1799unsafe impl Send for OpTensorDescriptor {}
1800
1801impl core::fmt::Debug for OpTensorDescriptor {
1802    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1803        f.debug_struct("OpTensorDescriptor")
1804            .field("desc", &self.desc)
1805            .finish_non_exhaustive()
1806    }
1807}
1808
1809impl OpTensorDescriptor {
1810    /// Build an op-tensor descriptor for `op` with `compute` as the
1811    /// accumulation dtype. NaN propagation defaults to `PropagateNan`.
1812    pub fn new(op: OpTensorOp, compute: DType) -> Result<Self> {
1813        let cu = cudnn()?;
1814        let create = cu.cudnn_create_op_tensor_descriptor()?;
1815        let set = cu.cudnn_set_op_tensor_descriptor()?;
1816        let mut desc: cudnnOpTensorDescriptor_t = core::ptr::null_mut();
1817        check(unsafe { create(&mut desc) })?;
1818        let this = Self { desc };
1819        check(unsafe {
1820            set(
1821                this.desc,
1822                op.raw(),
1823                compute.raw(),
1824                cudnnNanPropagation_t::PropagateNan,
1825            )
1826        })?;
1827        Ok(this)
1828    }
1829
1830    /// Raw descriptor.
1831    #[inline]
1832    pub fn as_raw(&self) -> cudnnOpTensorDescriptor_t {
1833        self.desc
1834    }
1835}
1836
1837impl Drop for OpTensorDescriptor {
1838    fn drop(&mut self) {
1839        if let Ok(c) = cudnn() {
1840            if let Ok(cu) = c.cudnn_destroy_op_tensor_descriptor() {
1841                let _ = unsafe { cu(self.desc) };
1842            }
1843        }
1844    }
1845}
1846
1847/// `C = alpha1 * op(A) + alpha2 * op(B) + beta * C` element-wise.
1848#[allow(clippy::too_many_arguments)]
1849pub fn op_tensor<T: DeviceRepr>(
1850    handle: &Handle,
1851    op: &OpTensorDescriptor,
1852    alpha1: f32,
1853    a_desc: &TensorDescriptor,
1854    a: &DeviceBuffer<T>,
1855    alpha2: f32,
1856    b_desc: &TensorDescriptor,
1857    b: &DeviceBuffer<T>,
1858    beta: f32,
1859    c_desc: &TensorDescriptor,
1860    c: &mut DeviceBuffer<T>,
1861) -> Result<()> {
1862    let cu_crate = cudnn()?;
1863    let cu = cu_crate.cudnn_op_tensor()?;
1864    check(unsafe {
1865        cu(
1866            handle.handle,
1867            op.desc,
1868            &alpha1 as *const f32 as *const core::ffi::c_void,
1869            a_desc.desc,
1870            a.as_raw().0 as *const core::ffi::c_void,
1871            &alpha2 as *const f32 as *const core::ffi::c_void,
1872            b_desc.desc,
1873            b.as_raw().0 as *const core::ffi::c_void,
1874            &beta as *const f32 as *const core::ffi::c_void,
1875            c_desc.desc,
1876            c.as_raw().0 as *mut core::ffi::c_void,
1877        )
1878    })
1879}
1880
1881/// Reduction op for [`ReduceTensorDescriptor`] / [`reduce_tensor`].
1882#[derive(Copy, Clone, Debug, Eq, PartialEq)]
1883pub enum ReduceOp {
1884    /// Sum of inputs.
1885    Add,
1886    /// Product of inputs.
1887    Mul,
1888    /// Minimum value.
1889    Min,
1890    /// Maximum value.
1891    Max,
1892    /// Maximum absolute value (`max(|x|)`).
1893    AbsMax,
1894    /// Arithmetic mean.
1895    Avg,
1896    /// L1 norm: `sum(|x|)`.
1897    Norm1,
1898    /// L2 norm: `sqrt(sum(x^2))`.
1899    Norm2,
1900    /// Product, skipping any zero inputs.
1901    MulNoZeros,
1902}
1903
1904impl ReduceOp {
1905    fn raw(self) -> cudnnReduceTensorOp_t {
1906        match self {
1907            Self::Add => cudnnReduceTensorOp_t::Add,
1908            Self::Mul => cudnnReduceTensorOp_t::Mul,
1909            Self::Min => cudnnReduceTensorOp_t::Min,
1910            Self::Max => cudnnReduceTensorOp_t::Max,
1911            Self::AbsMax => cudnnReduceTensorOp_t::Amax,
1912            Self::Avg => cudnnReduceTensorOp_t::Avg,
1913            Self::Norm1 => cudnnReduceTensorOp_t::Norm1,
1914            Self::Norm2 => cudnnReduceTensorOp_t::Norm2,
1915            Self::MulNoZeros => cudnnReduceTensorOp_t::MulNoZeros,
1916        }
1917    }
1918}
1919
1920/// A reduce-tensor descriptor: reduction op + compute dtype.
1921pub struct ReduceTensorDescriptor {
1922    desc: cudnnReduceTensorDescriptor_t,
1923}
1924
1925unsafe impl Send for ReduceTensorDescriptor {}
1926
1927impl core::fmt::Debug for ReduceTensorDescriptor {
1928    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1929        f.debug_struct("ReduceTensorDescriptor")
1930            .field("desc", &self.desc)
1931            .finish_non_exhaustive()
1932    }
1933}
1934
1935impl ReduceTensorDescriptor {
1936    /// Build a reduce-tensor descriptor for `op` with `compute` as the
1937    /// accumulation dtype. NaN propagation defaults to `PropagateNan`;
1938    /// indices are not returned (use the lower-level cuDNN API directly
1939    /// for arg-reductions).
1940    pub fn new(op: ReduceOp, compute: DType) -> Result<Self> {
1941        let cu = cudnn()?;
1942        let create = cu.cudnn_create_reduce_tensor_descriptor()?;
1943        let set = cu.cudnn_set_reduce_tensor_descriptor()?;
1944        let mut desc: cudnnReduceTensorDescriptor_t = core::ptr::null_mut();
1945        check(unsafe { create(&mut desc) })?;
1946        let this = Self { desc };
1947        check(unsafe {
1948            set(
1949                this.desc,
1950                op.raw(),
1951                compute.raw(),
1952                cudnnNanPropagation_t::PropagateNan,
1953                cudnnReduceTensorIndices_t::NoIndices,
1954                cudnnIndicesType_t::U32,
1955            )
1956        })?;
1957        Ok(this)
1958    }
1959
1960    /// Workspace bytes required to run [`reduce_tensor`] reducing `a` into `c`.
1961    pub fn workspace_size(
1962        &self,
1963        handle: &Handle,
1964        a: &TensorDescriptor,
1965        c: &TensorDescriptor,
1966    ) -> Result<usize> {
1967        let cu = cudnn()?;
1968        let q = cu.cudnn_get_reduction_workspace_size()?;
1969        let mut size = 0usize;
1970        check(unsafe { q(handle.handle, self.desc, a.desc, c.desc, &mut size) })?;
1971        Ok(size)
1972    }
1973
1974    /// Raw descriptor.
1975    #[inline]
1976    pub fn as_raw(&self) -> cudnnReduceTensorDescriptor_t {
1977        self.desc
1978    }
1979}
1980
1981impl Drop for ReduceTensorDescriptor {
1982    fn drop(&mut self) {
1983        if let Ok(c) = cudnn() {
1984            if let Ok(cu) = c.cudnn_destroy_reduce_tensor_descriptor() {
1985                let _ = unsafe { cu(self.desc) };
1986            }
1987        }
1988    }
1989}
1990
1991/// `C = alpha * reduce(A) + beta * C` over the axes where `A`'s extent is
1992/// preserved and `C`'s is 1.
1993#[allow(clippy::too_many_arguments)]
1994pub fn reduce_tensor<T: DeviceRepr>(
1995    handle: &Handle,
1996    reducer: &ReduceTensorDescriptor,
1997    workspace: &mut DeviceBuffer<u8>,
1998    alpha: f32,
1999    a_desc: &TensorDescriptor,
2000    a: &DeviceBuffer<T>,
2001    beta: f32,
2002    c_desc: &TensorDescriptor,
2003    c: &mut DeviceBuffer<T>,
2004) -> Result<()> {
2005    let cu_crate = cudnn()?;
2006    let cu = cu_crate.cudnn_reduce_tensor()?;
2007    check(unsafe {
2008        cu(
2009            handle.handle,
2010            reducer.desc,
2011            core::ptr::null_mut(),
2012            0,
2013            workspace.as_raw().0 as *mut core::ffi::c_void,
2014            workspace.byte_size(),
2015            &alpha as *const f32 as *const core::ffi::c_void,
2016            a_desc.desc,
2017            a.as_raw().0 as *const core::ffi::c_void,
2018            &beta as *const f32 as *const core::ffi::c_void,
2019            c_desc.desc,
2020            c.as_raw().0 as *mut core::ffi::c_void,
2021        )
2022    })
2023}
2024
2025/// `C = alpha * A + beta * C` with broadcast. Useful for adding a per-channel
2026/// bias to a feature map.
2027pub fn add_tensor<T: DeviceRepr>(
2028    handle: &Handle,
2029    alpha: f32,
2030    a_desc: &TensorDescriptor,
2031    a: &DeviceBuffer<T>,
2032    beta: f32,
2033    c_desc: &TensorDescriptor,
2034    c: &mut DeviceBuffer<T>,
2035) -> Result<()> {
2036    let cu_crate = cudnn()?;
2037    let cu = cu_crate.cudnn_add_tensor()?;
2038    check(unsafe {
2039        cu(
2040            handle.handle,
2041            &alpha as *const f32 as *const core::ffi::c_void,
2042            a_desc.desc,
2043            a.as_raw().0 as *const core::ffi::c_void,
2044            &beta as *const f32 as *const core::ffi::c_void,
2045            c_desc.desc,
2046            c.as_raw().0 as *mut core::ffi::c_void,
2047        )
2048    })
2049}
2050
2051// ---- backend (Graph) API --------------------------------------------------
2052
2053/// Thin wrapper over a `cudnnBackendDescriptor_t`. Used to build Graph-API
2054/// operation graphs and execution plans. Callers set attributes with
2055/// [`BackendDescriptor::set_attribute_raw`] using the constants in
2056/// [`baracuda_cudnn_sys::cudnnBackendAttributeName_t`] /
2057/// [`baracuda_cudnn_sys::cudnnBackendAttributeType_t`].
2058pub struct BackendDescriptor {
2059    desc: cudnnBackendDescriptor_t,
2060    finalized: bool,
2061}
2062
2063unsafe impl Send for BackendDescriptor {}
2064
2065impl core::fmt::Debug for BackendDescriptor {
2066    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2067        f.debug_struct("BackendDescriptor")
2068            .field("desc", &self.desc)
2069            .field("finalized", &self.finalized)
2070            .finish()
2071    }
2072}
2073
2074impl BackendDescriptor {
2075    /// Allocate and initialize a backend descriptor of the given `kind`.
2076    /// The descriptor is unfinalized; set attributes via
2077    /// [`set_attribute_raw`](Self::set_attribute_raw) then call
2078    /// [`finalize`](Self::finalize).
2079    pub fn new(kind: cudnnBackendDescriptorType_t) -> Result<Self> {
2080        let cu = cudnn()?;
2081        let create = cu.cudnn_backend_create_descriptor()?;
2082        let init = cu.cudnn_backend_initialize()?;
2083        let mut desc: cudnnBackendDescriptor_t = core::ptr::null_mut();
2084        check(unsafe { create(kind, &mut desc) })?;
2085        let this = Self {
2086            desc,
2087            finalized: false,
2088        };
2089        check(unsafe { init(this.desc) })?;
2090        Ok(this)
2091    }
2092
2093    /// Set an attribute by name/type. `element_count` is the number of
2094    /// elements in `array_of_elements` (not byte count).
2095    ///
2096    /// # Safety
2097    /// `array_of_elements` must point to valid data matching the attribute's
2098    /// expected type and count.
2099    pub unsafe fn set_attribute_raw(
2100        &self,
2101        name: cudnnBackendAttributeName_t,
2102        ty: cudnnBackendAttributeType_t,
2103        element_count: i64,
2104        array_of_elements: *const core::ffi::c_void,
2105    ) -> Result<()> { unsafe {
2106        let cu = cudnn()?;
2107        let f = cu.cudnn_backend_set_attribute()?;
2108        check(f(self.desc, name, ty, element_count, array_of_elements))
2109    }}
2110
2111    /// Lock in the descriptor's attributes. Idempotent — repeated calls
2112    /// are no-ops.
2113    pub fn finalize(&mut self) -> Result<()> {
2114        if self.finalized {
2115            return Ok(());
2116        }
2117        let cu = cudnn()?;
2118        let f = cu.cudnn_backend_finalize()?;
2119        check(unsafe { f(self.desc) })?;
2120        self.finalized = true;
2121        Ok(())
2122    }
2123
2124    /// Execute an execution-plan descriptor. `self` should be the plan
2125    /// descriptor; `variant_pack` provides tensor addresses + workspace.
2126    pub fn execute(&self, handle: &Handle, variant_pack: &BackendDescriptor) -> Result<()> {
2127        let cu = cudnn()?;
2128        let f = cu.cudnn_backend_execute()?;
2129        check(unsafe { f(handle.handle, self.desc, variant_pack.desc) })
2130    }
2131
2132    /// Raw descriptor.
2133    #[inline]
2134    pub fn as_raw(&self) -> cudnnBackendDescriptor_t {
2135        self.desc
2136    }
2137}
2138
2139impl Drop for BackendDescriptor {
2140    fn drop(&mut self) {
2141        if let Ok(c) = cudnn() {
2142            if let Ok(cu) = c.cudnn_backend_destroy_descriptor() {
2143                let _ = unsafe { cu(self.desc) };
2144            }
2145        }
2146    }
2147}
2148
/// Re-exports of the backend attribute and descriptor-type enums so callers
/// don't have to reach into the sys crate.
2151pub use baracuda_cudnn_sys::{
2152    cudnnBackendAttributeName_t as BackendAttrName,
2153    cudnnBackendAttributeType_t as BackendAttrType,
2154    cudnnBackendDescriptorType_t as BackendDescType,
2155};
2156
2157// ---- CTC loss ------------------------------------------------------------
2158
2159use baracuda_cudnn_sys::cudnnCTCLossDescriptor_t;
2160
2161/// CTC (Connectionist Temporal Classification) loss descriptor.
2162pub struct CtcLossDescriptor {
2163    desc: cudnnCTCLossDescriptor_t,
2164}
2165
2166unsafe impl Send for CtcLossDescriptor {}
2167
2168impl core::fmt::Debug for CtcLossDescriptor {
2169    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2170        f.debug_struct("CtcLossDescriptor")
2171            .field("desc", &self.desc)
2172            .finish_non_exhaustive()
2173    }
2174}
2175
2176impl CtcLossDescriptor {
2177    /// Build a CTC-loss descriptor with `compute` as the accumulation dtype.
2178    pub fn new(compute: DType) -> Result<Self> {
2179        let cu = cudnn()?;
2180        let create = cu.cudnn_create_ctc_loss_descriptor()?;
2181        let set = cu.cudnn_set_ctc_loss_descriptor()?;
2182        let mut desc: cudnnCTCLossDescriptor_t = core::ptr::null_mut();
2183        check(unsafe { create(&mut desc) })?;
2184        let this = Self { desc };
2185        check(unsafe { set(this.desc, compute.raw()) })?;
2186        Ok(this)
2187    }
2188
2189    /// Raw descriptor.
2190    #[inline]
2191    pub fn as_raw(&self) -> cudnnCTCLossDescriptor_t {
2192        self.desc
2193    }
2194}
2195
2196impl Drop for CtcLossDescriptor {
2197    fn drop(&mut self) {
2198        if let Ok(c) = cudnn() {
2199            if let Ok(cu) = c.cudnn_destroy_ctc_loss_descriptor() {
2200                let _ = unsafe { cu(self.desc) };
2201            }
2202        }
2203    }
2204}
2205
2206/// Bytes of scratch workspace needed for [`ctc_loss`].
2207#[allow(clippy::too_many_arguments)]
2208pub fn ctc_loss_workspace_size(
2209    handle: &Handle,
2210    probs: &TensorDescriptor,
2211    gradients: &TensorDescriptor,
2212    labels: &[i32],
2213    label_lengths: &[i32],
2214    input_lengths: &[i32],
2215    algo: i32,
2216    desc: &CtcLossDescriptor,
2217) -> Result<usize> {
2218    let cu = cudnn()?;
2219    let q = cu.cudnn_get_ctc_loss_workspace_size()?;
2220    let mut size = 0usize;
2221    check(unsafe {
2222        q(
2223            handle.handle,
2224            probs.desc,
2225            gradients.desc,
2226            labels.as_ptr(),
2227            label_lengths.as_ptr(),
2228            input_lengths.as_ptr(),
2229            algo,
2230            desc.desc,
2231            &mut size,
2232        )
2233    })?;
2234    Ok(size)
2235}
2236
2237/// CTC (Connectionist Temporal Classification) loss.
2238#[allow(clippy::too_many_arguments)]
2239pub fn ctc_loss<T: DeviceRepr>(
2240    handle: &Handle,
2241    probs_desc: &TensorDescriptor,
2242    probs: &DeviceBuffer<T>,
2243    labels: &[i32],
2244    label_lengths: &[i32],
2245    input_lengths: &[i32],
2246    costs: &mut DeviceBuffer<T>,
2247    gradients_desc: &TensorDescriptor,
2248    gradients: &mut DeviceBuffer<T>,
2249    algo: i32,
2250    desc: &CtcLossDescriptor,
2251    workspace: &mut DeviceBuffer<u8>,
2252) -> Result<()> {
2253    let c = cudnn()?;
2254    let cu = c.cudnn_ctc_loss()?;
2255    check(unsafe {
2256        cu(
2257            handle.handle,
2258            probs_desc.desc,
2259            probs.as_raw().0 as *const core::ffi::c_void,
2260            labels.as_ptr(),
2261            label_lengths.as_ptr(),
2262            input_lengths.as_ptr(),
2263            costs.as_raw().0 as *mut core::ffi::c_void,
2264            gradients_desc.desc,
2265            gradients.as_raw().0 as *mut core::ffi::c_void,
2266            algo,
2267            desc.desc,
2268            workspace.as_raw().0 as *mut core::ffi::c_void,
2269            workspace.byte_size(),
2270        )
2271    })
2272}
2273
2274// ---- Spatial transformer ------------------------------------------------
2275
2276use baracuda_cudnn_sys::cudnnSpatialTransformerDescriptor_t;
2277
2278/// Spatial-transformer descriptor: sampler kind + output shape.
2279pub struct SpatialTransformerDescriptor {
2280    desc: cudnnSpatialTransformerDescriptor_t,
2281}
2282
2283unsafe impl Send for SpatialTransformerDescriptor {}
2284
2285impl core::fmt::Debug for SpatialTransformerDescriptor {
2286    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2287        f.debug_struct("SpatialTransformerDescriptor")
2288            .field("desc", &self.desc)
2289            .finish_non_exhaustive()
2290    }
2291}
2292
2293impl SpatialTransformerDescriptor {
2294    /// `sampler_type` matches `CUDNN_SAMPLER_BILINEAR = 0`.
2295    pub fn new(sampler_type: i32, dtype: DType, dims: &[i32]) -> Result<Self> {
2296        let cu = cudnn()?;
2297        let create = cu.cudnn_create_spatial_transformer_descriptor()?;
2298        let set = cu.cudnn_set_spatial_transformer_nd_descriptor()?;
2299        let mut desc: cudnnSpatialTransformerDescriptor_t = core::ptr::null_mut();
2300        check(unsafe { create(&mut desc) })?;
2301        let this = Self { desc };
2302        check(unsafe {
2303            set(
2304                this.desc,
2305                sampler_type,
2306                dtype.raw(),
2307                dims.len() as core::ffi::c_int,
2308                dims.as_ptr(),
2309            )
2310        })?;
2311        Ok(this)
2312    }
2313
2314    /// Raw descriptor.
2315    #[inline]
2316    pub fn as_raw(&self) -> cudnnSpatialTransformerDescriptor_t {
2317        self.desc
2318    }
2319}
2320
2321impl Drop for SpatialTransformerDescriptor {
2322    fn drop(&mut self) {
2323        if let Ok(c) = cudnn() {
2324            if let Ok(cu) = c.cudnn_destroy_spatial_transformer_descriptor() {
2325                let _ = unsafe { cu(self.desc) };
2326            }
2327        }
2328    }
2329}
2330
2331/// Compute the sampling grid from the affine transform `theta`.
2332pub fn spatial_tf_grid_generator<T: DeviceRepr>(
2333    handle: &Handle,
2334    st: &SpatialTransformerDescriptor,
2335    theta: &DeviceBuffer<T>,
2336    grid: &mut DeviceBuffer<T>,
2337) -> Result<()> {
2338    let c = cudnn()?;
2339    let cu = c.cudnn_spatial_tf_grid_generator_forward()?;
2340    check(unsafe {
2341        cu(
2342            handle.handle,
2343            st.desc,
2344            theta.as_raw().0 as *const core::ffi::c_void,
2345            grid.as_raw().0 as *mut core::ffi::c_void,
2346        )
2347    })
2348}
2349
2350/// Bilinearly sample `x` at `grid` points to produce `y`.
2351#[allow(clippy::too_many_arguments)]
2352pub fn spatial_tf_sampler<T: DeviceRepr>(
2353    handle: &Handle,
2354    st: &SpatialTransformerDescriptor,
2355    alpha: f32,
2356    x_desc: &TensorDescriptor,
2357    x: &DeviceBuffer<T>,
2358    grid: &DeviceBuffer<T>,
2359    beta: f32,
2360    y_desc: &TensorDescriptor,
2361    y: &mut DeviceBuffer<T>,
2362) -> Result<()> {
2363    let c = cudnn()?;
2364    let cu = c.cudnn_spatial_tf_sampler_forward()?;
2365    check(unsafe {
2366        cu(
2367            handle.handle,
2368            st.desc,
2369            &alpha as *const f32 as *const core::ffi::c_void,
2370            x_desc.desc,
2371            x.as_raw().0 as *const core::ffi::c_void,
2372            grid.as_raw().0 as *const core::ffi::c_void,
2373            &beta as *const f32 as *const core::ffi::c_void,
2374            y_desc.desc,
2375            y.as_raw().0 as *mut core::ffi::c_void,
2376        )
2377    })
2378}
2379
2380// ============================================================================
2381// Tier 1 — math type / reorder type / fused conv+bias+act / activation back
2382// ============================================================================
2383
2384/// Math-type selector for [`ConvolutionDescriptor::set_math_type`] —
2385/// controls tensor-core eligibility.
2386#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
2387pub enum MathType {
2388    /// Standard FMA-only math; tensor cores not used.
2389    #[default]
2390    Default,
2391    /// Allow tensor-core math (Volta+).
2392    TensorOp,
2393    /// Allow tensor-core math with implicit half-precision conversion.
2394    TensorOpAllowConversion,
2395    /// Strict FMA-only.
2396    FmaOnly,
2397}
2398
2399impl MathType {
2400    pub(crate) fn raw(self) -> cudnnMathType_t {
2401        match self {
2402            MathType::Default => cudnnMathType_t::DefaultMath,
2403            MathType::TensorOp => cudnnMathType_t::TensorOpMath,
2404            MathType::TensorOpAllowConversion => cudnnMathType_t::TensorOpMathAllowConversion,
2405            MathType::FmaOnly => cudnnMathType_t::FmaMath,
2406        }
2407    }
2408    pub(crate) fn from_raw(raw: cudnnMathType_t) -> Self {
2409        match raw {
2410            cudnnMathType_t::DefaultMath => MathType::Default,
2411            cudnnMathType_t::TensorOpMath => MathType::TensorOp,
2412            cudnnMathType_t::TensorOpMathAllowConversion => MathType::TensorOpAllowConversion,
2413            cudnnMathType_t::FmaMath => MathType::FmaOnly,
2414        }
2415    }
2416}
2417
2418/// Filter / bias reorder selector for INT8 quantized inference paths.
2419#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
2420pub enum ReorderType {
2421    /// cuDNN-chosen reorder layout. **Default.**
2422    #[default]
2423    Default,
2424    /// Leave the filter / bias buffers in their original layout.
2425    None,
2426}
2427
2428impl ReorderType {
2429    pub(crate) fn raw(self) -> cudnnReorderType_t {
2430        match self {
2431            ReorderType::Default => cudnnReorderType_t::DefaultReorder,
2432            ReorderType::None => cudnnReorderType_t::NoReorder,
2433        }
2434    }
2435    pub(crate) fn from_raw(raw: cudnnReorderType_t) -> Self {
2436        match raw {
2437            cudnnReorderType_t::DefaultReorder => ReorderType::Default,
2438            cudnnReorderType_t::NoReorder => ReorderType::None,
2439        }
2440    }
2441}
2442
2443/// Pre-process filter / bias buffers for INT8 inference.
2444///
2445/// # Safety
2446/// Output buffers must have at least the same byte size as the inputs.
2447/// `bias_data` / `reordered_bias` may be null when `reorder_bias` is false.
2448#[allow(clippy::too_many_arguments)]
2449pub unsafe fn reorder_filter_and_bias(
2450    handle: &Handle,
2451    filter_desc: &FilterDescriptor,
2452    reorder: ReorderType,
2453    filter_data: *const core::ffi::c_void,
2454    reordered_filter: *mut core::ffi::c_void,
2455    reorder_bias: bool,
2456    bias_data: *const core::ffi::c_void,
2457    reordered_bias: *mut core::ffi::c_void,
2458) -> Result<()> { unsafe {
2459    let c = cudnn()?;
2460    let f = c.cudnn_reorder_filter_and_bias()?;
2461    check(f(
2462        handle.handle, filter_desc.desc, reorder.raw(),
2463        filter_data, reordered_filter,
2464        reorder_bias as core::ffi::c_int, bias_data, reordered_bias,
2465    ))
2466}}
2467
2468/// Fused convolution + bias + activation forward:
2469/// `Y = activation(alpha1 * conv(X, W) + alpha2 * Z + bias)`.
2470/// `Z` may alias `Y` for in-place residual add.
2471#[allow(clippy::too_many_arguments)]
2472pub fn convolution_bias_activation_forward<T: DeviceRepr>(
2473    handle: &Handle,
2474    alpha1: f32,
2475    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
2476    w_desc: &FilterDescriptor, w: &DeviceBuffer<T>,
2477    conv: &ConvolutionDescriptor,
2478    algo: FwdAlgo,
2479    workspace: &mut DeviceBuffer<u8>,
2480    alpha2: f32,
2481    z_desc: &TensorDescriptor, z: &DeviceBuffer<T>,
2482    bias_desc: &TensorDescriptor, bias: &DeviceBuffer<T>,
2483    activation: &ActivationDescriptor,
2484    y_desc: &TensorDescriptor, y: &mut DeviceBuffer<T>,
2485) -> Result<()> {
2486    let c = cudnn()?;
2487    let cu = c.cudnn_convolution_bias_activation_forward()?;
2488    check(unsafe {
2489        cu(
2490            handle.handle,
2491            &alpha1 as *const f32 as *const core::ffi::c_void,
2492            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
2493            w_desc.desc, w.as_raw().0 as *const core::ffi::c_void,
2494            conv.desc, algo.raw(),
2495            workspace.as_raw().0 as *mut core::ffi::c_void, workspace.byte_size(),
2496            &alpha2 as *const f32 as *const core::ffi::c_void,
2497            z_desc.desc, z.as_raw().0 as *const core::ffi::c_void,
2498            bias_desc.desc, bias.as_raw().0 as *const core::ffi::c_void,
2499            activation.desc,
2500            y_desc.desc, y.as_raw().0 as *mut core::ffi::c_void,
2501        )
2502    })
2503}
2504
2505/// `dx = alpha * activation_backward(y, dy, x) + beta * dx`.
2506#[allow(clippy::too_many_arguments)]
2507pub fn activation_backward<T: DeviceRepr>(
2508    handle: &Handle,
2509    activation: &ActivationDescriptor,
2510    alpha: f32,
2511    y_desc: &TensorDescriptor, y: &DeviceBuffer<T>,
2512    dy_desc: &TensorDescriptor, dy: &DeviceBuffer<T>,
2513    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
2514    beta: f32,
2515    dx_desc: &TensorDescriptor, dx: &mut DeviceBuffer<T>,
2516) -> Result<()> {
2517    let c = cudnn()?;
2518    let cu = c.cudnn_activation_backward()?;
2519    check(unsafe {
2520        cu(
2521            handle.handle, activation.desc,
2522            &alpha as *const f32 as *const core::ffi::c_void,
2523            y_desc.desc, y.as_raw().0 as *const core::ffi::c_void,
2524            dy_desc.desc, dy.as_raw().0 as *const core::ffi::c_void,
2525            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
2526            &beta as *const f32 as *const core::ffi::c_void,
2527            dx_desc.desc, dx.as_raw().0 as *mut core::ffi::c_void,
2528        )
2529    })
2530}
2531
2532/// Cross-channel LRN backward.
2533#[allow(clippy::too_many_arguments)]
2534pub fn lrn_cross_channel_backward<T: DeviceRepr>(
2535    handle: &Handle, lrn: &LrnDescriptor, mode: i32,
2536    alpha: f32,
2537    y_desc: &TensorDescriptor, y: &DeviceBuffer<T>,
2538    dy_desc: &TensorDescriptor, dy: &DeviceBuffer<T>,
2539    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
2540    beta: f32,
2541    dx_desc: &TensorDescriptor, dx: &mut DeviceBuffer<T>,
2542) -> Result<()> {
2543    let c = cudnn()?;
2544    let cu = c.cudnn_lrn_cross_channel_backward()?;
2545    check(unsafe {
2546        cu(
2547            handle.handle, lrn.desc, mode,
2548            &alpha as *const f32 as *const core::ffi::c_void,
2549            y_desc.desc, y.as_raw().0 as *const core::ffi::c_void,
2550            dy_desc.desc, dy.as_raw().0 as *const core::ffi::c_void,
2551            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
2552            &beta as *const f32 as *const core::ffi::c_void,
2553            dx_desc.desc, dx.as_raw().0 as *mut core::ffi::c_void,
2554        )
2555    })
2556}
2557
2558/// Bytes of indices buffer required for index-returning reductions.
2559pub fn reduction_indices_size(
2560    handle: &Handle,
2561    reducer: &ReduceTensorDescriptor,
2562    a: &TensorDescriptor,
2563    c: &TensorDescriptor,
2564) -> Result<usize> {
2565    let cu = cudnn()?;
2566    let q = cu.cudnn_get_reduction_indices_size()?;
2567    let mut size = 0usize;
2568    check(unsafe { q(handle.handle, reducer.desc, a.desc, c.desc, &mut size) })?;
2569    Ok(size)
2570}
2571
2572impl ActivationDescriptor {
2573    /// Set the β parameter on a Swish activation.
2574    pub fn set_swish_beta(&self, beta: f64) -> Result<()> {
2575        let c = cudnn()?;
2576        let f = c.cudnn_set_activation_descriptor_swish_beta()?;
2577        check(unsafe { f(self.desc, beta) })
2578    }
2579    /// Read back the Swish β parameter.
2580    pub fn swish_beta(&self) -> Result<f64> {
2581        let c = cudnn()?;
2582        let f = c.cudnn_get_activation_descriptor_swish_beta()?;
2583        let mut b: f64 = 0.0;
2584        check(unsafe { f(self.desc, &mut b) })?;
2585        Ok(b)
2586    }
2587}
2588
2589// ============================================================================
2590// Tier 2 — Algorithm finders / pickers
2591// ============================================================================
2592
2593/// Per-algorithm performance record returned by the forward-convolution finders.
2594pub use baracuda_cudnn_sys::cudnnConvolutionFwdAlgoPerf_t as FwdAlgoPerf;
2595/// Per-algorithm performance record returned by the backward-data convolution finders.
2596pub use baracuda_cudnn_sys::cudnnConvolutionBwdDataAlgoPerf_t as BwdDataAlgoPerf;
2597/// Per-algorithm performance record returned by the backward-filter convolution finders.
2598pub use baracuda_cudnn_sys::cudnnConvolutionBwdFilterAlgoPerf_t as BwdFilterAlgoPerf;
2599
2600/// Heuristic-pick the top-N forward-convolution algorithms (cheap; doesn't run them).
2601pub fn get_convolution_forward_algorithm(
2602    handle: &Handle,
2603    src: &TensorDescriptor, filter: &FilterDescriptor,
2604    conv: &ConvolutionDescriptor, dst: &TensorDescriptor,
2605    requested: i32,
2606) -> Result<Vec<FwdAlgoPerf>> {
2607    let cu = cudnn()?;
2608    let f = cu.cudnn_get_convolution_forward_algorithm_v7()?;
2609    let mut returned: core::ffi::c_int = 0;
2610    let mut buf: Vec<FwdAlgoPerf> = Vec::with_capacity(requested as usize);
2611    let raw = unsafe {
2612        f(handle.handle, src.desc, filter.desc, conv.desc, dst.desc,
2613          requested, &mut returned, buf.as_mut_ptr())
2614    };
2615    check(raw)?;
2616    unsafe { buf.set_len(returned as usize); }
2617    Ok(buf)
2618}
2619
2620/// Run all candidate forward-convolution algorithms and return measured runtimes.
2621pub fn find_convolution_forward_algorithm(
2622    handle: &Handle,
2623    src: &TensorDescriptor, filter: &FilterDescriptor,
2624    conv: &ConvolutionDescriptor, dst: &TensorDescriptor,
2625    requested: i32,
2626) -> Result<Vec<FwdAlgoPerf>> {
2627    let cu = cudnn()?;
2628    let f = cu.cudnn_find_convolution_forward_algorithm()?;
2629    let mut returned: core::ffi::c_int = 0;
2630    let mut buf: Vec<FwdAlgoPerf> = Vec::with_capacity(requested as usize);
2631    let raw = unsafe {
2632        f(handle.handle, src.desc, filter.desc, conv.desc, dst.desc,
2633          requested, &mut returned, buf.as_mut_ptr())
2634    };
2635    check(raw)?;
2636    unsafe { buf.set_len(returned as usize); }
2637    Ok(buf)
2638}
2639
2640/// Heuristic-pick backward-data convolution algorithms.
2641pub fn get_convolution_backward_data_algorithm(
2642    handle: &Handle,
2643    filter: &FilterDescriptor, diff: &TensorDescriptor,
2644    conv: &ConvolutionDescriptor, grad: &TensorDescriptor,
2645    requested: i32,
2646) -> Result<Vec<BwdDataAlgoPerf>> {
2647    let cu = cudnn()?;
2648    let f = cu.cudnn_get_convolution_backward_data_algorithm_v7()?;
2649    let mut returned: core::ffi::c_int = 0;
2650    let mut buf: Vec<BwdDataAlgoPerf> = Vec::with_capacity(requested as usize);
2651    let raw = unsafe {
2652        f(handle.handle, filter.desc, diff.desc, conv.desc, grad.desc,
2653          requested, &mut returned, buf.as_mut_ptr())
2654    };
2655    check(raw)?;
2656    unsafe { buf.set_len(returned as usize); }
2657    Ok(buf)
2658}
2659
2660/// Heuristic-pick backward-filter convolution algorithms.
2661pub fn get_convolution_backward_filter_algorithm(
2662    handle: &Handle,
2663    src: &TensorDescriptor, diff: &TensorDescriptor,
2664    conv: &ConvolutionDescriptor, grad: &FilterDescriptor,
2665    requested: i32,
2666) -> Result<Vec<BwdFilterAlgoPerf>> {
2667    let cu = cudnn()?;
2668    let f = cu.cudnn_get_convolution_backward_filter_algorithm_v7()?;
2669    let mut returned: core::ffi::c_int = 0;
2670    let mut buf: Vec<BwdFilterAlgoPerf> = Vec::with_capacity(requested as usize);
2671    let raw = unsafe {
2672        f(handle.handle, src.desc, diff.desc, conv.desc, grad.desc,
2673          requested, &mut returned, buf.as_mut_ptr())
2674    };
2675    check(raw)?;
2676    unsafe { buf.set_len(returned as usize); }
2677    Ok(buf)
2678}
2679
2680// ============================================================================
2681// Tier 3 — Generic Normalization API enums (cuDNN 8+) + workspace queries
2682// ============================================================================
2683
2684/// Generic-normalization parameter sharing pattern (cuDNN 8+).
2685#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
2686pub enum NormMode {
2687    /// One scale/bias per `(C, H, W)` cell.
2688    PerActivation,
2689    /// One scale/bias per `C` channel (shared spatially). **Default.**
2690    #[default]
2691    PerChannel,
2692}
2693impl NormMode {
2694    fn raw(self) -> cudnnNormMode_t {
2695        match self {
2696            NormMode::PerActivation => cudnnNormMode_t::PerActivation,
2697            NormMode::PerChannel => cudnnNormMode_t::PerChannel,
2698        }
2699    }
2700}
2701
2702/// Generic-normalization kernel selector.
2703#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
2704pub enum NormAlgo {
2705    /// Standard normalization kernel. **Default.**
2706    #[default]
2707    Standard,
2708    /// Persistent-mode fast kernel where available.
2709    Persist,
2710}
2711impl NormAlgo {
2712    fn raw(self) -> cudnnNormAlgo_t {
2713        match self {
2714            NormAlgo::Standard => cudnnNormAlgo_t::Standard,
2715            NormAlgo::Persist => cudnnNormAlgo_t::Persist,
2716        }
2717    }
2718}
2719
2720/// Optional fused op for the generic-normalization API.
2721#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
2722pub enum NormOp {
2723    /// Plain normalization, no fused activation. **Default.**
2724    #[default]
2725    Norm,
2726    /// Fused `activation(norm(x))`.
2727    NormActivation,
2728    /// Fused `activation(norm(x) + z)` for residual add.
2729    NormAddActivation,
2730}
2731impl NormOp {
2732    fn raw(self) -> cudnnNormOps_t {
2733        match self {
2734            NormOp::Norm => cudnnNormOps_t::Norm,
2735            NormOp::NormActivation => cudnnNormOps_t::NormActivation,
2736            NormOp::NormAddActivation => cudnnNormOps_t::NormAddActivation,
2737        }
2738    }
2739}
2740
2741/// Optional fused op for the `*Ex` BatchNorm variants.
2742#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
2743pub enum BnOp {
2744    /// Plain batch normalization, no fused activation. **Default.**
2745    #[default]
2746    Bn,
2747    /// Fused `activation(bn(x))`.
2748    BnActivation,
2749    /// Fused `activation(bn(x) + z)` for residual add.
2750    BnAddActivation,
2751}
2752impl BnOp {
2753    fn raw(self) -> cudnnBatchNormOps_t {
2754        match self {
2755            BnOp::Bn => cudnnBatchNormOps_t::Bn,
2756            BnOp::BnActivation => cudnnBatchNormOps_t::BnActivation,
2757            BnOp::BnAddActivation => cudnnBatchNormOps_t::BnAddActivation,
2758        }
2759    }
2760}
2761
2762/// Workspace bytes for [`batch_normalization_forward_training_ex`].
2763#[allow(clippy::too_many_arguments)]
2764pub fn batch_normalization_forward_training_ex_workspace_size(
2765    handle: &Handle,
2766    mode: BatchNormMode, bn_ops: BnOp,
2767    x: &TensorDescriptor, z: &TensorDescriptor, y: &TensorDescriptor,
2768    bn_smbv: &TensorDescriptor, activation: &ActivationDescriptor,
2769) -> Result<usize> {
2770    let cu = cudnn()?;
2771    let f = cu.cudnn_get_batch_normalization_forward_training_ex_workspace_size()?;
2772    let mut size = 0usize;
2773    check(unsafe {
2774        f(handle.handle, mode.raw(), bn_ops.raw(),
2775          x.desc, z.desc, y.desc, bn_smbv.desc, activation.desc, &mut size)
2776    })?;
2777    Ok(size)
2778}
2779
2780/// Workspace bytes for [`batch_normalization_backward_ex`].
2781#[allow(clippy::too_many_arguments)]
2782pub fn batch_normalization_backward_ex_workspace_size(
2783    handle: &Handle,
2784    mode: BatchNormMode, bn_ops: BnOp,
2785    x: &TensorDescriptor, y: &TensorDescriptor, dy: &TensorDescriptor,
2786    dz: &TensorDescriptor, dx: &TensorDescriptor,
2787    d_bn_scale_bias: &TensorDescriptor, activation: &ActivationDescriptor,
2788) -> Result<usize> {
2789    let cu = cudnn()?;
2790    let f = cu.cudnn_get_batch_normalization_backward_ex_workspace_size()?;
2791    let mut size = 0usize;
2792    check(unsafe {
2793        f(handle.handle, mode.raw(), bn_ops.raw(),
2794          x.desc, y.desc, dy.desc, dz.desc, dx.desc,
2795          d_bn_scale_bias.desc, activation.desc, &mut size)
2796    })?;
2797    Ok(size)
2798}
2799
2800/// Reserve-space bytes for the `*Ex` BatchNorm pair.
2801pub fn batch_normalization_training_ex_reserve_space_size(
2802    handle: &Handle,
2803    mode: BatchNormMode, bn_ops: BnOp,
2804    activation: &ActivationDescriptor, x: &TensorDescriptor,
2805) -> Result<usize> {
2806    let cu = cudnn()?;
2807    let f = cu.cudnn_get_batch_normalization_training_ex_reserve_space_size()?;
2808    let mut size = 0usize;
2809    check(unsafe {
2810        f(handle.handle, mode.raw(), bn_ops.raw(), activation.desc, x.desc, &mut size)
2811    })?;
2812    Ok(size)
2813}
2814
2815// ============================================================================
2816// Tier 4 — RNN v8 + companion descriptors
2817// ============================================================================
2818
2819/// Owned RNN descriptor.
2820pub struct RnnDescriptor {
2821    desc: baracuda_cudnn_sys::cudnnRNNDescriptor_t,
2822}
2823unsafe impl Send for RnnDescriptor {}
2824impl core::fmt::Debug for RnnDescriptor {
2825    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2826        f.debug_struct("RnnDescriptor").field("desc", &self.desc).finish_non_exhaustive()
2827    }
2828}
2829impl RnnDescriptor {
2830    /// Allocate an empty RNN descriptor. Configure it with
2831    /// [`set_v8`](Self::set_v8) before use.
2832    pub fn new() -> Result<Self> {
2833        let c = cudnn()?;
2834        let create = c.cudnn_create_rnn_descriptor()?;
2835        let mut desc: baracuda_cudnn_sys::cudnnRNNDescriptor_t = core::ptr::null_mut();
2836        check(unsafe { create(&mut desc) })?;
2837        Ok(Self { desc })
2838    }
2839
2840    /// Configure with the v8 setup. After this, call
2841    /// [`build_rnn_dynamic`] to bind a specific minibatch size.
2842    #[allow(clippy::too_many_arguments)]
2843    pub fn set_v8(
2844        &self,
2845        algo: i32, cell_mode: i32, bias_mode: i32,
2846        dir_mode: i32, input_mode: i32,
2847        data_type: DType, math_prec: DType, math_type: MathType,
2848        input_size: i32, hidden_size: i32, proj_size: i32, num_layers: i32,
2849        dropout: &DropoutDescriptor, aux_flags: u32,
2850    ) -> Result<()> {
2851        use baracuda_cudnn_sys::{cudnnDirectionMode_t, cudnnRNNAlgo_t, cudnnRNNInputMode_t, cudnnRNNMode_t};
2852        let c = cudnn()?;
2853        let f = c.cudnn_set_rnn_descriptor_v8()?;
2854        let algo = match algo {
2855            0 => cudnnRNNAlgo_t::Standard,
2856            1 => cudnnRNNAlgo_t::PersistStatic,
2857            2 => cudnnRNNAlgo_t::PersistDynamic,
2858            _ => cudnnRNNAlgo_t::PersistStaticSmallH,
2859        };
2860        let cell = match cell_mode {
2861            0 => cudnnRNNMode_t::ReluRnn,
2862            1 => cudnnRNNMode_t::TanhRnn,
2863            2 => cudnnRNNMode_t::Lstm,
2864            _ => cudnnRNNMode_t::Gru,
2865        };
2866        let dir = if dir_mode == 1 { cudnnDirectionMode_t::Bidirectional } else { cudnnDirectionMode_t::Unidirectional };
2867        let im = if input_mode == 1 { cudnnRNNInputMode_t::SkipInput } else { cudnnRNNInputMode_t::LinearInput };
2868        check(unsafe {
2869            f(self.desc, algo, cell, bias_mode, dir, im,
2870              data_type.raw(), math_prec.raw(), math_type.raw(),
2871              input_size, hidden_size, proj_size, num_layers,
2872              dropout.desc, aux_flags)
2873        })
2874    }
2875
2876    /// Raw descriptor.
2877    #[inline]
2878    pub fn as_raw(&self) -> baracuda_cudnn_sys::cudnnRNNDescriptor_t { self.desc }
2879}
2880impl Drop for RnnDescriptor {
2881    fn drop(&mut self) {
2882        if let Ok(c) = cudnn() {
2883            if let Ok(cu) = c.cudnn_destroy_rnn_descriptor() {
2884                let _ = unsafe { cu(self.desc) };
2885            }
2886        }
2887    }
2888}
2889
2890/// Owned RNN-data descriptor used by the v8 RNN forward / backward path.
2891pub struct RnnDataDescriptor {
2892    desc: baracuda_cudnn_sys::cudnnRNNDataDescriptor_t,
2893}
2894unsafe impl Send for RnnDataDescriptor {}
2895impl core::fmt::Debug for RnnDataDescriptor {
2896    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2897        f.debug_struct("RnnDataDescriptor").field("desc", &self.desc).finish_non_exhaustive()
2898    }
2899}
2900impl RnnDataDescriptor {
2901    /// Allocate an empty RNN-data descriptor. Configure attributes through
2902    /// the raw cuDNN setter exposed via [`as_raw`](Self::as_raw).
2903    pub fn new() -> Result<Self> {
2904        let c = cudnn()?;
2905        let create = c.cudnn_create_rnn_data_descriptor()?;
2906        let mut desc: baracuda_cudnn_sys::cudnnRNNDataDescriptor_t = core::ptr::null_mut();
2907        check(unsafe { create(&mut desc) })?;
2908        Ok(Self { desc })
2909    }
2910    /// Raw descriptor.
2911    #[inline]
2912    pub fn as_raw(&self) -> baracuda_cudnn_sys::cudnnRNNDataDescriptor_t { self.desc }
2913}
2914impl Drop for RnnDataDescriptor {
2915    fn drop(&mut self) {
2916        if let Ok(c) = cudnn() {
2917            if let Ok(cu) = c.cudnn_destroy_rnn_data_descriptor() {
2918                let _ = unsafe { cu(self.desc) };
2919            }
2920        }
2921    }
2922}
2923
2924/// Finalize an RNN descriptor for a specific minibatch size.
2925pub fn build_rnn_dynamic(handle: &Handle, rnn: &RnnDescriptor, mini_batch: i32) -> Result<()> {
2926    let c = cudnn()?;
2927    let f = c.cudnn_build_rnn_dynamic()?;
2928    check(unsafe { f(handle.handle, rnn.desc, mini_batch) })
2929}
2930
2931/// Returns `(work_space_size, reserve_space_size)`.
2932/// `fwd_mode = 0` for inference, `1` for training.
2933pub fn rnn_temp_space_sizes(
2934    handle: &Handle, rnn: &RnnDescriptor, fwd_mode: i32, x: &RnnDataDescriptor,
2935) -> Result<(usize, usize)> {
2936    let c = cudnn()?;
2937    let f = c.cudnn_get_rnn_temp_space_sizes()?;
2938    let (mut ws, mut rs) = (0usize, 0usize);
2939    check(unsafe { f(handle.handle, rnn.desc, fwd_mode, x.desc, &mut ws, &mut rs) })?;
2940    Ok((ws, rs))
2941}
2942
2943/// Bytes the RNN's weight space needs.
2944pub fn rnn_weight_space_size(handle: &Handle, rnn: &RnnDescriptor) -> Result<usize> {
2945    let c = cudnn()?;
2946    let f = c.cudnn_get_rnn_weight_space_size()?;
2947    let mut size = 0usize;
2948    check(unsafe { f(handle.handle, rnn.desc, &mut size) })?;
2949    Ok(size)
2950}
2951
2952// ============================================================================
2953// Tier 5 — Multi-head attention
2954// ============================================================================
2955
2956/// Multi-head attention descriptor.
2957pub struct AttnDescriptor {
2958    desc: cudnnAttnDescriptor_t,
2959}
2960unsafe impl Send for AttnDescriptor {}
2961impl core::fmt::Debug for AttnDescriptor {
2962    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2963        f.debug_struct("AttnDescriptor").field("desc", &self.desc).finish_non_exhaustive()
2964    }
2965}
2966impl AttnDescriptor {
2967    /// Allocate an empty multi-head attention descriptor. Configure with
2968    /// [`set`](Self::set) before passing it to the attention forward / backward
2969    /// functions.
2970    pub fn new() -> Result<Self> {
2971        let c = cudnn()?;
2972        let cu = c.cudnn_create_attn_descriptor()?;
2973        let mut desc: cudnnAttnDescriptor_t = core::ptr::null_mut();
2974        check(unsafe { cu(&mut desc) })?;
2975        Ok(Self { desc })
2976    }
2977
2978    /// Configure the descriptor. See `cudnnSetAttnDescriptor` in the
2979    /// cuDNN reference for each parameter.
2980    #[allow(clippy::too_many_arguments)]
2981    pub fn set(
2982        &self,
2983        attn_mode: u32, n_heads: i32, sm_scaler: f64,
2984        data_type: DType, compute_prec: DType, math_type: MathType,
2985        attn_dropout: &DropoutDescriptor, post_dropout: &DropoutDescriptor,
2986        q_size: i32, k_size: i32, v_size: i32,
2987        q_proj_size: i32, k_proj_size: i32, v_proj_size: i32, o_proj_size: i32,
2988        qo_max_seq_length: i32, kv_max_seq_length: i32,
2989        max_batch_size: i32, max_beam_size: i32,
2990    ) -> Result<()> {
2991        let c = cudnn()?;
2992        let f = c.cudnn_set_attn_descriptor()?;
2993        check(unsafe {
2994            f(self.desc, attn_mode, n_heads, sm_scaler,
2995              data_type.raw(), compute_prec.raw(), math_type.raw(),
2996              attn_dropout.desc, post_dropout.desc,
2997              q_size, k_size, v_size,
2998              q_proj_size, k_proj_size, v_proj_size, o_proj_size,
2999              qo_max_seq_length, kv_max_seq_length,
3000              max_batch_size, max_beam_size)
3001        })
3002    }
3003
3004    /// Raw descriptor.
3005    #[inline]
3006    pub fn as_raw(&self) -> cudnnAttnDescriptor_t { self.desc }
3007}
3008impl Drop for AttnDescriptor {
3009    fn drop(&mut self) {
3010        if let Ok(c) = cudnn() {
3011            if let Ok(cu) = c.cudnn_destroy_attn_descriptor() {
3012                let _ = unsafe { cu(self.desc) };
3013            }
3014        }
3015    }
3016}
3017
3018/// Buffer requirements `(weights, work_space, reserve_space)`.
3019pub fn multi_head_attn_buffers(
3020    handle: &Handle, attn: &AttnDescriptor,
3021) -> Result<(usize, usize, usize)> {
3022    let c = cudnn()?;
3023    let f = c.cudnn_get_multi_head_attn_buffers()?;
3024    let (mut w, mut ws, mut rs) = (0usize, 0usize, 0usize);
3025    check(unsafe { f(handle.handle, attn.desc, &mut w, &mut ws, &mut rs) })?;
3026    Ok((w, ws, rs))
3027}
3028
3029/// Sequence-data descriptor used by multi-head attention.
3030pub struct SeqDataDescriptor {
3031    desc: cudnnSeqDataDescriptor_t,
3032}
3033unsafe impl Send for SeqDataDescriptor {}
3034impl core::fmt::Debug for SeqDataDescriptor {
3035    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
3036        f.debug_struct("SeqDataDescriptor").field("desc", &self.desc).finish_non_exhaustive()
3037    }
3038}
3039impl SeqDataDescriptor {
3040    /// Allocate an empty sequence-data descriptor. Configure with
3041    /// [`set`](Self::set) before passing it to a multi-head attention call.
3042    pub fn new() -> Result<Self> {
3043        let c = cudnn()?;
3044        let cu = c.cudnn_create_seq_data_descriptor()?;
3045        let mut desc: cudnnSeqDataDescriptor_t = core::ptr::null_mut();
3046        check(unsafe { cu(&mut desc) })?;
3047        Ok(Self { desc })
3048    }
3049
3050    /// # Safety
3051    /// `padding_fill` must point to a value of the descriptor's data type.
3052    #[allow(clippy::too_many_arguments)]
3053    pub unsafe fn set(
3054        &self,
3055        data_type: DType,
3056        dim_a: &[i32], axes: &[i32], seq_length_array: &[i32],
3057        padding_fill: *const core::ffi::c_void,
3058    ) -> Result<()> { unsafe {
3059        let c = cudnn()?;
3060        let f = c.cudnn_set_seq_data_descriptor()?;
3061        check(f(
3062            self.desc, data_type.raw(),
3063            dim_a.len() as core::ffi::c_int,
3064            dim_a.as_ptr(), axes.as_ptr(),
3065            seq_length_array.len(), seq_length_array.as_ptr(),
3066            padding_fill,
3067        ))
3068    }}
3069
3070    /// Raw descriptor.
3071    #[inline]
3072    pub fn as_raw(&self) -> cudnnSeqDataDescriptor_t { self.desc }
3073}
3074impl Drop for SeqDataDescriptor {
3075    fn drop(&mut self) {
3076        if let Ok(c) = cudnn() {
3077            if let Ok(cu) = c.cudnn_destroy_seq_data_descriptor() {
3078                let _ = unsafe { cu(self.desc) };
3079            }
3080        }
3081    }
3082}
3083
3084/// Re-exports for callers that want raw type access.
3085pub use baracuda_cudnn_sys::{cudnnMathType_t as RawMathType, cudnnReorderType_t as RawReorderType};
3086
3087// ============================================================================
3088// Tier 1 leftovers — 4-D descriptor readback + DropoutDescriptor get/restore
3089// ============================================================================
3090
3091impl TensorDescriptor {
3092    /// Strided 4-D constructor — per-axis strides instead of the
3093    /// row-major / channels-last layouts [`new_4d`](Self::new_4d) implies.
3094    #[allow(clippy::too_many_arguments)]
3095    pub fn new_4d_ex(
3096        dtype: DType,
3097        n: i32, c: i32, h: i32, w: i32,
3098        n_stride: i32, c_stride: i32, h_stride: i32, w_stride: i32,
3099    ) -> Result<Self> {
3100        let cu = cudnn()?;
3101        let create = cu.cudnn_create_tensor_descriptor()?;
3102        let set = cu.cudnn_set_tensor_4d_descriptor_ex()?;
3103        let mut desc: cudnnTensorDescriptor_t = core::ptr::null_mut();
3104        check(unsafe { create(&mut desc) })?;
3105        let this = Self { desc };
3106        check(unsafe {
3107            set(this.desc, dtype.raw(), n, c, h, w,
3108                n_stride, c_stride, h_stride, w_stride)
3109        })?;
3110        Ok(this)
3111    }
3112
3113    /// Read the 4-D parameters back out: `(dtype, n, c, h, w, n_stride,
3114    /// c_stride, h_stride, w_stride)`.
3115    #[allow(clippy::type_complexity)]
3116    pub fn get_4d(&self) -> Result<(DType, i32, i32, i32, i32, i32, i32, i32, i32)> {
3117        let cu = cudnn()?;
3118        let f = cu.cudnn_get_tensor_4d_descriptor()?;
3119        let mut dt = cudnnDataType_t::Float;
3120        let (mut n, mut c, mut h, mut w) = (0i32, 0i32, 0i32, 0i32);
3121        let (mut ns, mut cs, mut hs, mut ws) = (0i32, 0i32, 0i32, 0i32);
3122        check(unsafe {
3123            f(self.desc, &mut dt, &mut n, &mut c, &mut h, &mut w,
3124              &mut ns, &mut cs, &mut hs, &mut ws)
3125        })?;
3126        let dtype = match dt {
3127            cudnnDataType_t::Float => DType::F32,
3128            cudnnDataType_t::Double => DType::F64,
3129            cudnnDataType_t::Half => DType::F16,
3130            cudnnDataType_t::BFloat16 => DType::BF16,
3131            cudnnDataType_t::Int8 => DType::I8,
3132            cudnnDataType_t::Int32 => DType::I32,
3133            _ => DType::F32,
3134        };
3135        Ok((dtype, n, c, h, w, ns, cs, hs, ws))
3136    }
3137}
3138
3139impl FilterDescriptor {
3140    /// Read 4-D filter parameters: `(dtype, format, k, c, h, w)`.
3141    pub fn get_4d(&self) -> Result<(DType, TensorFormat, i32, i32, i32, i32)> {
3142        let cu = cudnn()?;
3143        let f = cu.cudnn_get_filter_4d_descriptor()?;
3144        let mut dt = cudnnDataType_t::Float;
3145        let mut fmt = cudnnTensorFormat_t::Nchw;
3146        let (mut k, mut c, mut h, mut w) = (0i32, 0i32, 0i32, 0i32);
3147        check(unsafe {
3148            f(self.desc, &mut dt, &mut fmt, &mut k, &mut c, &mut h, &mut w)
3149        })?;
3150        let dtype = match dt {
3151            cudnnDataType_t::Float => DType::F32,
3152            cudnnDataType_t::Double => DType::F64,
3153            cudnnDataType_t::Half => DType::F16,
3154            cudnnDataType_t::BFloat16 => DType::BF16,
3155            cudnnDataType_t::Int8 => DType::I8,
3156            cudnnDataType_t::Int32 => DType::I32,
3157            _ => DType::F32,
3158        };
3159        let format = match fmt {
3160            cudnnTensorFormat_t::Nchw => TensorFormat::Nchw,
3161            cudnnTensorFormat_t::Nhwc => TensorFormat::Nhwc,
3162            _ => TensorFormat::Nchw,
3163        };
3164        Ok((dtype, format, k, c, h, w))
3165    }
3166}
3167
3168impl DropoutDescriptor {
3169    /// Read a dropout descriptor's current state. Returns `(dropout_p,
3170    /// states_ptr, seed)`. The states pointer is owned by cuDNN.
3171    pub fn get(&self, handle: &Handle) -> Result<(f32, *mut core::ffi::c_void, u64)> {
3172        let cu = cudnn()?;
3173        let f = cu.cudnn_get_dropout_descriptor()?;
3174        let mut dropout: f32 = 0.0;
3175        let mut states: *mut core::ffi::c_void = core::ptr::null_mut();
3176        let mut seed: u64 = 0;
3177        check(unsafe { f(self.desc, handle.handle, &mut dropout, &mut states, &mut seed) })?;
3178        Ok((dropout, states, seed))
3179    }
3180
3181    /// Reattach a previously-saved RNG state buffer to this descriptor.
3182    /// Useful for reproducible eval / resume.
3183    ///
3184    /// # Safety
3185    /// `states` must be a buffer of at least [`dropout_states_size`] bytes
3186    /// from the same `handle`, valid for the descriptor's lifetime.
3187    pub unsafe fn restore(
3188        &self, handle: &Handle, dropout: f32,
3189        states: *mut core::ffi::c_void, state_size: usize, seed: u64,
3190    ) -> Result<()> { unsafe {
3191        let cu = cudnn()?;
3192        let f = cu.cudnn_restore_dropout_descriptor()?;
3193        check(f(self.desc, handle.handle, dropout, states, state_size, seed))
3194    }}
3195}
3196
3197// ============================================================================
3198// BatchNormalization "Ex" — actual forward/backward (workspace queries are
3199// already above).
3200// ============================================================================
3201
3202/// BN training forward with optional fused activation / residual add.
3203#[allow(clippy::too_many_arguments)]
3204pub fn batch_normalization_forward_training_ex<T: DeviceRepr>(
3205    handle: &Handle,
3206    mode: BatchNormMode, bn_ops: BnOp,
3207    alpha: f32, beta: f32,
3208    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
3209    z_desc: &TensorDescriptor, z: &DeviceBuffer<T>,
3210    y_desc: &TensorDescriptor, y: &mut DeviceBuffer<T>,
3211    bn_smbv_desc: &TensorDescriptor,
3212    bn_scale: &DeviceBuffer<T>, bn_bias: &DeviceBuffer<T>,
3213    exponential_avg_factor: f64,
3214    running_mean: &mut DeviceBuffer<T>, running_var: &mut DeviceBuffer<T>,
3215    epsilon: f64,
3216    saved_mean: &mut DeviceBuffer<T>, saved_inv_var: &mut DeviceBuffer<T>,
3217    activation: &ActivationDescriptor,
3218    workspace: &mut DeviceBuffer<u8>, reserve: &mut DeviceBuffer<u8>,
3219) -> Result<()> {
3220    let c = cudnn()?;
3221    let cu = c.cudnn_batch_normalization_forward_training_ex()?;
3222    check(unsafe {
3223        cu(
3224            handle.handle, mode.raw(), bn_ops.raw(),
3225            &alpha as *const f32 as *const core::ffi::c_void,
3226            &beta as *const f32 as *const core::ffi::c_void,
3227            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
3228            z_desc.desc, z.as_raw().0 as *const core::ffi::c_void,
3229            y_desc.desc, y.as_raw().0 as *mut core::ffi::c_void,
3230            bn_smbv_desc.desc,
3231            bn_scale.as_raw().0 as *const core::ffi::c_void,
3232            bn_bias.as_raw().0 as *const core::ffi::c_void,
3233            exponential_avg_factor,
3234            running_mean.as_raw().0 as *mut core::ffi::c_void,
3235            running_var.as_raw().0 as *mut core::ffi::c_void,
3236            epsilon,
3237            saved_mean.as_raw().0 as *mut core::ffi::c_void,
3238            saved_inv_var.as_raw().0 as *mut core::ffi::c_void,
3239            activation.desc,
3240            workspace.as_raw().0 as *mut core::ffi::c_void, workspace.byte_size(),
3241            reserve.as_raw().0 as *mut core::ffi::c_void, reserve.byte_size(),
3242        )
3243    })
3244}
3245
3246/// BN backward matching [`batch_normalization_forward_training_ex`].
3247#[allow(clippy::too_many_arguments)]
3248pub fn batch_normalization_backward_ex<T: DeviceRepr>(
3249    handle: &Handle,
3250    mode: BatchNormMode, bn_ops: BnOp,
3251    alpha_data: f32, beta_data: f32,
3252    alpha_param: f32, beta_param: f32,
3253    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
3254    y_desc: &TensorDescriptor, y: &DeviceBuffer<T>,
3255    dy_desc: &TensorDescriptor, dy: &DeviceBuffer<T>,
3256    dz_desc: &TensorDescriptor, dz: &mut DeviceBuffer<T>,
3257    dx_desc: &TensorDescriptor, dx: &mut DeviceBuffer<T>,
3258    d_bn_scale_bias_desc: &TensorDescriptor,
3259    bn_scale: &DeviceBuffer<T>, bn_bias: &DeviceBuffer<T>,
3260    d_bn_scale: &mut DeviceBuffer<T>, d_bn_bias: &mut DeviceBuffer<T>,
3261    epsilon: f64,
3262    saved_mean: &DeviceBuffer<T>, saved_inv_var: &DeviceBuffer<T>,
3263    activation: &ActivationDescriptor,
3264    workspace: &mut DeviceBuffer<u8>, reserve: &mut DeviceBuffer<u8>,
3265) -> Result<()> {
3266    let c = cudnn()?;
3267    let cu = c.cudnn_batch_normalization_backward_ex()?;
3268    check(unsafe {
3269        cu(
3270            handle.handle, mode.raw(), bn_ops.raw(),
3271            &alpha_data as *const f32 as *const core::ffi::c_void,
3272            &beta_data as *const f32 as *const core::ffi::c_void,
3273            &alpha_param as *const f32 as *const core::ffi::c_void,
3274            &beta_param as *const f32 as *const core::ffi::c_void,
3275            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
3276            y_desc.desc, y.as_raw().0 as *const core::ffi::c_void,
3277            dy_desc.desc, dy.as_raw().0 as *const core::ffi::c_void,
3278            dz_desc.desc, dz.as_raw().0 as *mut core::ffi::c_void,
3279            dx_desc.desc, dx.as_raw().0 as *mut core::ffi::c_void,
3280            d_bn_scale_bias_desc.desc,
3281            bn_scale.as_raw().0 as *const core::ffi::c_void,
3282            bn_bias.as_raw().0 as *const core::ffi::c_void,
3283            d_bn_scale.as_raw().0 as *mut core::ffi::c_void,
3284            d_bn_bias.as_raw().0 as *mut core::ffi::c_void,
3285            epsilon,
3286            saved_mean.as_raw().0 as *const core::ffi::c_void,
3287            saved_inv_var.as_raw().0 as *const core::ffi::c_void,
3288            activation.desc,
3289            workspace.as_raw().0 as *mut core::ffi::c_void, workspace.byte_size(),
3290            reserve.as_raw().0 as *mut core::ffi::c_void, reserve.byte_size(),
3291        )
3292    })
3293}
3294
3295// ============================================================================
3296// Tier 3 - Generic Normalization API ops (cuDNN 8+)
3297// ============================================================================
3298
3299/// Inference-time generic normalization.
3300#[allow(clippy::too_many_arguments)]
3301pub fn normalization_forward_inference<T: DeviceRepr>(
3302    handle: &Handle,
3303    mode: NormMode, ops: NormOp, algo: NormAlgo,
3304    alpha: f32, beta: f32,
3305    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
3306    norm_scale_bias_desc: &TensorDescriptor,
3307    norm_scale: &DeviceBuffer<T>, norm_bias: &DeviceBuffer<T>,
3308    norm_mean_var_desc: &TensorDescriptor,
3309    estimated_mean: &DeviceBuffer<T>, estimated_var: &DeviceBuffer<T>,
3310    z_desc: &TensorDescriptor, z: &DeviceBuffer<T>,
3311    activation: &ActivationDescriptor,
3312    y_desc: &TensorDescriptor, y: &mut DeviceBuffer<T>,
3313    epsilon: f64, group_count: i32,
3314) -> Result<()> {
3315    let c = cudnn()?;
3316    let cu = c.cudnn_normalization_forward_inference()?;
3317    check(unsafe {
3318        cu(
3319            handle.handle, mode.raw(), ops.raw(), algo.raw(),
3320            &alpha as *const f32 as *const core::ffi::c_void,
3321            &beta as *const f32 as *const core::ffi::c_void,
3322            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
3323            norm_scale_bias_desc.desc,
3324            norm_scale.as_raw().0 as *const core::ffi::c_void,
3325            norm_bias.as_raw().0 as *const core::ffi::c_void,
3326            norm_mean_var_desc.desc,
3327            estimated_mean.as_raw().0 as *const core::ffi::c_void,
3328            estimated_var.as_raw().0 as *const core::ffi::c_void,
3329            z_desc.desc, z.as_raw().0 as *const core::ffi::c_void,
3330            activation.desc,
3331            y_desc.desc, y.as_raw().0 as *mut core::ffi::c_void,
3332            epsilon, group_count,
3333        )
3334    })
3335}
3336
3337/// Workspace bytes for normalization_forward_training.
3338#[allow(clippy::too_many_arguments)]
3339pub fn normalization_forward_training_workspace_size(
3340    handle: &Handle,
3341    mode: NormMode, ops: NormOp, algo: NormAlgo,
3342    x_desc: &TensorDescriptor, z_desc: &TensorDescriptor,
3343    y_desc: &TensorDescriptor, norm_scale_bias_desc: &TensorDescriptor,
3344    activation: &ActivationDescriptor, norm_mean_var_desc: &TensorDescriptor,
3345    group_count: i32,
3346) -> Result<usize> {
3347    let c = cudnn()?;
3348    let f = c.cudnn_get_normalization_forward_training_workspace_size()?;
3349    let mut size = 0usize;
3350    check(unsafe {
3351        f(handle.handle, mode.raw(), ops.raw(), algo.raw(),
3352          x_desc.desc, z_desc.desc, y_desc.desc, norm_scale_bias_desc.desc,
3353          activation.desc, norm_mean_var_desc.desc, &mut size, group_count)
3354    })?;
3355    Ok(size)
3356}
3357
3358/// Workspace bytes for normalization_backward.
3359#[allow(clippy::too_many_arguments)]
3360pub fn normalization_backward_workspace_size(
3361    handle: &Handle,
3362    mode: NormMode, ops: NormOp, algo: NormAlgo,
3363    x_desc: &TensorDescriptor, y_desc: &TensorDescriptor,
3364    dy_desc: &TensorDescriptor, dz_desc: &TensorDescriptor,
3365    dx_desc: &TensorDescriptor, d_norm_scale_bias_desc: &TensorDescriptor,
3366    activation: &ActivationDescriptor, norm_mean_var_desc: &TensorDescriptor,
3367    group_count: i32,
3368) -> Result<usize> {
3369    let c = cudnn()?;
3370    let f = c.cudnn_get_normalization_backward_workspace_size()?;
3371    let mut size = 0usize;
3372    check(unsafe {
3373        f(handle.handle, mode.raw(), ops.raw(), algo.raw(),
3374          x_desc.desc, y_desc.desc, dy_desc.desc, dz_desc.desc,
3375          dx_desc.desc, d_norm_scale_bias_desc.desc,
3376          activation.desc, norm_mean_var_desc.desc, &mut size, group_count)
3377    })?;
3378    Ok(size)
3379}
3380
3381/// Reserve-space bytes for the training fwd/bwd pair.
3382pub fn normalization_training_reserve_space_size(
3383    handle: &Handle,
3384    mode: NormMode, ops: NormOp, algo: NormAlgo,
3385    activation: &ActivationDescriptor, x_desc: &TensorDescriptor,
3386    group_count: i32,
3387) -> Result<usize> {
3388    let c = cudnn()?;
3389    let f = c.cudnn_get_normalization_training_reserve_space_size()?;
3390    let mut size = 0usize;
3391    check(unsafe {
3392        f(handle.handle, mode.raw(), ops.raw(), algo.raw(),
3393          activation.desc, x_desc.desc, &mut size, group_count)
3394    })?;
3395    Ok(size)
3396}
3397
3398/// Training-time forward generic normalization.
3399#[allow(clippy::too_many_arguments)]
3400pub fn normalization_forward_training<T: DeviceRepr>(
3401    handle: &Handle,
3402    mode: NormMode, ops: NormOp, algo: NormAlgo,
3403    alpha: f32, beta: f32,
3404    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
3405    norm_scale_bias_desc: &TensorDescriptor,
3406    norm_scale: &DeviceBuffer<T>, norm_bias: &DeviceBuffer<T>,
3407    exponential_avg_factor: f64,
3408    norm_mean_var_desc: &TensorDescriptor,
3409    running_mean: &mut DeviceBuffer<T>, running_var: &mut DeviceBuffer<T>,
3410    epsilon: f64,
3411    saved_mean: &mut DeviceBuffer<T>, saved_inv_var: &mut DeviceBuffer<T>,
3412    activation: &ActivationDescriptor,
3413    z_desc: &TensorDescriptor, z: &DeviceBuffer<T>,
3414    y_desc: &TensorDescriptor, y: &mut DeviceBuffer<T>,
3415    workspace: &mut DeviceBuffer<u8>, reserve: &mut DeviceBuffer<u8>,
3416    group_count: i32,
3417) -> Result<()> {
3418    let c = cudnn()?;
3419    let cu = c.cudnn_normalization_forward_training()?;
3420    check(unsafe {
3421        cu(
3422            handle.handle, mode.raw(), ops.raw(), algo.raw(),
3423            &alpha as *const f32 as *const core::ffi::c_void,
3424            &beta as *const f32 as *const core::ffi::c_void,
3425            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
3426            norm_scale_bias_desc.desc,
3427            norm_scale.as_raw().0 as *const core::ffi::c_void,
3428            norm_bias.as_raw().0 as *const core::ffi::c_void,
3429            exponential_avg_factor,
3430            norm_mean_var_desc.desc,
3431            running_mean.as_raw().0 as *mut core::ffi::c_void,
3432            running_var.as_raw().0 as *mut core::ffi::c_void,
3433            epsilon,
3434            saved_mean.as_raw().0 as *mut core::ffi::c_void,
3435            saved_inv_var.as_raw().0 as *mut core::ffi::c_void,
3436            activation.desc,
3437            z_desc.desc, z.as_raw().0 as *const core::ffi::c_void,
3438            y_desc.desc, y.as_raw().0 as *mut core::ffi::c_void,
3439            workspace.as_raw().0 as *mut core::ffi::c_void, workspace.byte_size(),
3440            reserve.as_raw().0 as *mut core::ffi::c_void, reserve.byte_size(),
3441            group_count,
3442        )
3443    })
3444}
3445
3446/// Backward generic normalization.
3447#[allow(clippy::too_many_arguments)]
3448pub fn normalization_backward<T: DeviceRepr>(
3449    handle: &Handle,
3450    mode: NormMode, ops: NormOp, algo: NormAlgo,
3451    alpha_data: f32, beta_data: f32,
3452    alpha_param: f32, beta_param: f32,
3453    x_desc: &TensorDescriptor, x: &DeviceBuffer<T>,
3454    y_desc: &TensorDescriptor, y: &DeviceBuffer<T>,
3455    dy_desc: &TensorDescriptor, dy: &DeviceBuffer<T>,
3456    dz_desc: &TensorDescriptor, dz: &mut DeviceBuffer<T>,
3457    dx_desc: &TensorDescriptor, dx: &mut DeviceBuffer<T>,
3458    d_norm_scale_bias_desc: &TensorDescriptor,
3459    norm_scale: &DeviceBuffer<T>, norm_bias: &DeviceBuffer<T>,
3460    d_norm_scale: &mut DeviceBuffer<T>, d_norm_bias: &mut DeviceBuffer<T>,
3461    epsilon: f64,
3462    norm_mean_var_desc: &TensorDescriptor,
3463    saved_mean: &DeviceBuffer<T>, saved_inv_var: &DeviceBuffer<T>,
3464    activation: &ActivationDescriptor,
3465    workspace: &mut DeviceBuffer<u8>, reserve: &mut DeviceBuffer<u8>,
3466    group_count: i32,
3467) -> Result<()> {
3468    let c = cudnn()?;
3469    let cu = c.cudnn_normalization_backward()?;
3470    check(unsafe {
3471        cu(
3472            handle.handle, mode.raw(), ops.raw(), algo.raw(),
3473            &alpha_data as *const f32 as *const core::ffi::c_void,
3474            &beta_data as *const f32 as *const core::ffi::c_void,
3475            &alpha_param as *const f32 as *const core::ffi::c_void,
3476            &beta_param as *const f32 as *const core::ffi::c_void,
3477            x_desc.desc, x.as_raw().0 as *const core::ffi::c_void,
3478            y_desc.desc, y.as_raw().0 as *const core::ffi::c_void,
3479            dy_desc.desc, dy.as_raw().0 as *const core::ffi::c_void,
3480            dz_desc.desc, dz.as_raw().0 as *mut core::ffi::c_void,
3481            dx_desc.desc, dx.as_raw().0 as *mut core::ffi::c_void,
3482            d_norm_scale_bias_desc.desc,
3483            norm_scale.as_raw().0 as *const core::ffi::c_void,
3484            norm_bias.as_raw().0 as *const core::ffi::c_void,
3485            d_norm_scale.as_raw().0 as *mut core::ffi::c_void,
3486            d_norm_bias.as_raw().0 as *mut core::ffi::c_void,
3487            epsilon,
3488            norm_mean_var_desc.desc,
3489            saved_mean.as_raw().0 as *const core::ffi::c_void,
3490            saved_inv_var.as_raw().0 as *const core::ffi::c_void,
3491            activation.desc,
3492            workspace.as_raw().0 as *mut core::ffi::c_void, workspace.byte_size(),
3493            reserve.as_raw().0 as *mut core::ffi::c_void, reserve.byte_size(),
3494            group_count,
3495        )
3496    })
3497}
3498
3499// ============================================================================
3500// Tier 5 - Multi-head attention ops (forward + backward + weights query)
3501// ============================================================================
3502
3503/// Look up the descriptor of one of the (Q/K/V/O) weight matrices inside
3504/// the packed weights buffer.
3505///
3506/// `w_kind` matches cuDNN's `cudnnMultiHeadAttnWeightKind_t`:
3507///   0 = Q weights, 1 = K weights, 2 = V weights, 3 = O weights,
3508///   4 = Q bias, 5 = K bias, 6 = V bias, 7 = O bias.
3509///
3510/// # Safety
3511/// `weights` must point at the multi-head attention weight buffer
3512/// produced from `multi_head_attn_buffers`.
3513#[allow(clippy::too_many_arguments)]
3514pub unsafe fn get_multi_head_attn_weights(
3515    handle: &Handle,
3516    attn: &AttnDescriptor,
3517    w_kind: i32,
3518    weight_size_in_bytes: usize,
3519    weights: *const core::ffi::c_void,
3520    w_desc: &TensorDescriptor,
3521) -> Result<*mut core::ffi::c_void> { unsafe {
3522    let c = cudnn()?;
3523    let f = c.cudnn_get_multi_head_attn_weights()?;
3524    let mut addr: *mut core::ffi::c_void = core::ptr::null_mut();
3525    check(f(
3526        handle.handle, attn.desc, w_kind, weight_size_in_bytes, weights,
3527        w_desc.desc, &mut addr,
3528    ))?;
3529    Ok(addr)
3530}}
3531
3532/// Forward multi-head attention. The huge parameter list mirrors cuDNN's
3533/// `cudnnMultiHeadAttnForward` exactly; see the cuDNN reference for the
3534/// meaning of each window / sequence-length array.
3535///
3536/// # Safety
3537/// All device buffers must satisfy the size and alignment requirements
3538/// reported by [`multi_head_attn_buffers`]. `lo_win_idx`/`hi_win_idx`
3539/// must be host arrays of length `qo_max_seq_length`.
3540#[allow(clippy::too_many_arguments)]
3541pub unsafe fn multi_head_attn_forward(
3542    handle: &Handle,
3543    attn: &AttnDescriptor,
3544    curr_idx: i32,
3545    lo_win_idx: &[i32],
3546    hi_win_idx: &[i32],
3547    dev_seq_lengths_qo: *const i32,
3548    dev_seq_lengths_kv: *const i32,
3549    q_desc: &SeqDataDescriptor, queries: *const core::ffi::c_void,
3550    residuals: *const core::ffi::c_void,
3551    k_desc: &SeqDataDescriptor, keys: *const core::ffi::c_void,
3552    v_desc: &SeqDataDescriptor, values: *const core::ffi::c_void,
3553    o_desc: &SeqDataDescriptor, out: *mut core::ffi::c_void,
3554    weights: &DeviceBuffer<u8>,
3555    work_space: &mut DeviceBuffer<u8>,
3556    reserve_space: &mut DeviceBuffer<u8>,
3557) -> Result<()> { unsafe {
3558    let c = cudnn()?;
3559    let f = c.cudnn_multi_head_attn_forward()?;
3560    check(f(
3561        handle.handle, attn.desc,
3562        curr_idx, lo_win_idx.as_ptr(), hi_win_idx.as_ptr(),
3563        dev_seq_lengths_qo, dev_seq_lengths_kv,
3564        q_desc.desc, queries, residuals,
3565        k_desc.desc, keys,
3566        v_desc.desc, values,
3567        o_desc.desc, out,
3568        weights.byte_size(), weights.as_raw().0 as *const core::ffi::c_void,
3569        work_space.byte_size(), work_space.as_raw().0 as *mut core::ffi::c_void,
3570        reserve_space.byte_size(), reserve_space.as_raw().0 as *mut core::ffi::c_void,
3571    ))
3572}}
3573
3574/// Multi-head attention backward — data path (gradients w.r.t. Q/K/V).
3575///
3576/// # Safety
3577/// Same buffer-sizing rules as [`multi_head_attn_forward`].
3578#[allow(clippy::too_many_arguments)]
3579pub unsafe fn multi_head_attn_backward_data(
3580    handle: &Handle,
3581    attn: &AttnDescriptor,
3582    lo_win_idx: &[i32],
3583    hi_win_idx: &[i32],
3584    dev_seq_lengths_dqdo: *const i32,
3585    dev_seq_lengths_dkdv: *const i32,
3586    do_desc: &SeqDataDescriptor, dout: *const core::ffi::c_void,
3587    dq_desc: &SeqDataDescriptor, dqueries: *mut core::ffi::c_void,
3588    queries: *const core::ffi::c_void,
3589    dk_desc: &SeqDataDescriptor, dkeys: *mut core::ffi::c_void,
3590    keys: *const core::ffi::c_void,
3591    dv_desc: &SeqDataDescriptor, dvalues: *mut core::ffi::c_void,
3592    values: *const core::ffi::c_void,
3593    weights: &DeviceBuffer<u8>,
3594    work_space: &mut DeviceBuffer<u8>,
3595    reserve_space: &mut DeviceBuffer<u8>,
3596) -> Result<()> { unsafe {
3597    let c = cudnn()?;
3598    let f = c.cudnn_multi_head_attn_backward_data()?;
3599    check(f(
3600        handle.handle, attn.desc,
3601        lo_win_idx.as_ptr(), hi_win_idx.as_ptr(),
3602        dev_seq_lengths_dqdo, dev_seq_lengths_dkdv,
3603        do_desc.desc, dout,
3604        dq_desc.desc, dqueries, queries,
3605        dk_desc.desc, dkeys, keys,
3606        dv_desc.desc, dvalues, values,
3607        weights.byte_size(), weights.as_raw().0 as *const core::ffi::c_void,
3608        work_space.byte_size(), work_space.as_raw().0 as *mut core::ffi::c_void,
3609        reserve_space.byte_size(), reserve_space.as_raw().0 as *mut core::ffi::c_void,
3610    ))
3611}}
3612
3613/// Multi-head attention backward — weights path (gradient w.r.t. Q/K/V/O
3614/// projection weights). Pass `add_grad = true` to accumulate into
3615/// `dweights` (typical for multi-step training).
3616///
3617/// # Safety
3618/// Same as [`multi_head_attn_forward`].
3619#[allow(clippy::too_many_arguments)]
3620pub unsafe fn multi_head_attn_backward_weights(
3621    handle: &Handle,
3622    attn: &AttnDescriptor,
3623    add_grad: bool,
3624    q_desc: &SeqDataDescriptor, queries: *const core::ffi::c_void,
3625    k_desc: &SeqDataDescriptor, keys: *const core::ffi::c_void,
3626    v_desc: &SeqDataDescriptor, values: *const core::ffi::c_void,
3627    do_desc: &SeqDataDescriptor, dout: *const core::ffi::c_void,
3628    weights: &DeviceBuffer<u8>,
3629    dweights: &mut DeviceBuffer<u8>,
3630    work_space: &mut DeviceBuffer<u8>,
3631    reserve_space: &mut DeviceBuffer<u8>,
3632) -> Result<()> { unsafe {
3633    let c = cudnn()?;
3634    let f = c.cudnn_multi_head_attn_backward_weights()?;
3635    check(f(
3636        handle.handle, attn.desc, add_grad as core::ffi::c_int,
3637        q_desc.desc, queries,
3638        k_desc.desc, keys,
3639        v_desc.desc, values,
3640        do_desc.desc, dout,
3641        weights.byte_size(), weights.as_raw().0 as *const core::ffi::c_void,
3642        dweights.as_raw().0 as *mut core::ffi::c_void,
3643        work_space.byte_size(), work_space.as_raw().0 as *mut core::ffi::c_void,
3644        reserve_space.byte_size(), reserve_space.as_raw().0 as *mut core::ffi::c_void,
3645    ))
3646}}
3647
3648// ============================================================================
3649// Tier 4 (cont.) - RNN v8 forward + backward (data + weights)
3650// ============================================================================
3651
3652/// Forward pass of an RNN built via [`RnnDescriptor::set_v8`] /
3653/// [`build_rnn_dynamic`]. Pass `fwd_mode = 0` for inference (no reserve
3654/// space writes), `1` for training.
3655///
3656/// `dev_seq_lengths` is a device array of length `batch_size` giving the
3657/// valid timestep count per sequence. `hx`/`cx` may be null for an
3658/// implicit zero initial state; `hy`/`cy` may be null if the caller does
3659/// not need the final state.
3660///
3661/// # Safety
3662/// Buffer sizes must match what [`rnn_temp_space_sizes`] /
3663/// [`rnn_weight_space_size`] reported.
3664#[allow(clippy::too_many_arguments)]
3665pub unsafe fn rnn_forward(
3666    handle: &Handle,
3667    rnn: &RnnDescriptor,
3668    fwd_mode: i32,
3669    dev_seq_lengths: *const i32,
3670    x_desc: &RnnDataDescriptor, x: *const core::ffi::c_void,
3671    y_desc: &RnnDataDescriptor, y: *mut core::ffi::c_void,
3672    h_desc: &TensorDescriptor,
3673    hx: *const core::ffi::c_void,
3674    hy: *mut core::ffi::c_void,
3675    c_desc: &TensorDescriptor,
3676    cx: *const core::ffi::c_void,
3677    cy: *mut core::ffi::c_void,
3678    weight_space: &DeviceBuffer<u8>,
3679    work_space: &mut DeviceBuffer<u8>,
3680    reserve_space: &mut DeviceBuffer<u8>,
3681) -> Result<()> { unsafe {
3682    let c = cudnn()?;
3683    let f = c.cudnn_rnn_forward()?;
3684    check(f(
3685        handle.handle, rnn.desc, fwd_mode, dev_seq_lengths,
3686        x_desc.desc, x,
3687        y_desc.desc, y,
3688        h_desc.desc, hx, hy,
3689        c_desc.desc, cx, cy,
3690        weight_space.byte_size(), weight_space.as_raw().0 as *const core::ffi::c_void,
3691        work_space.byte_size(),   work_space.as_raw().0 as *mut core::ffi::c_void,
3692        reserve_space.byte_size(), reserve_space.as_raw().0 as *mut core::ffi::c_void,
3693    ))
3694}}
3695
3696/// RNN backward — data path (gradients w.r.t. inputs and initial states).
3697///
3698/// # Safety
3699/// Same buffer-sizing rules as [`rnn_forward`]. The reserve space must
3700/// be the exact one populated during the matching training-mode
3701/// `rnn_forward`.
3702#[allow(clippy::too_many_arguments)]
3703pub unsafe fn rnn_backward_data_v8(
3704    handle: &Handle,
3705    rnn: &RnnDescriptor,
3706    dev_seq_lengths: *const i32,
3707    y_desc: &RnnDataDescriptor,
3708    y: *const core::ffi::c_void,
3709    dy: *const core::ffi::c_void,
3710    x_desc: &RnnDataDescriptor,
3711    dx: *mut core::ffi::c_void,
3712    h_desc: &TensorDescriptor,
3713    hx: *const core::ffi::c_void,
3714    dhy: *const core::ffi::c_void,
3715    dhx: *mut core::ffi::c_void,
3716    c_desc: &TensorDescriptor,
3717    cx: *const core::ffi::c_void,
3718    dcy: *const core::ffi::c_void,
3719    dcx: *mut core::ffi::c_void,
3720    weight_space: &DeviceBuffer<u8>,
3721    work_space: &mut DeviceBuffer<u8>,
3722    reserve_space: &mut DeviceBuffer<u8>,
3723) -> Result<()> { unsafe {
3724    let c = cudnn()?;
3725    let f = c.cudnn_rnn_backward_data_v8()?;
3726    check(f(
3727        handle.handle, rnn.desc, dev_seq_lengths,
3728        y_desc.desc, y, dy,
3729        x_desc.desc, dx,
3730        h_desc.desc, hx, dhy, dhx,
3731        c_desc.desc, cx, dcy, dcx,
3732        weight_space.byte_size(), weight_space.as_raw().0 as *const core::ffi::c_void,
3733        work_space.byte_size(),   work_space.as_raw().0 as *mut core::ffi::c_void,
3734        reserve_space.byte_size(), reserve_space.as_raw().0 as *mut core::ffi::c_void,
3735    ))
3736}}
3737
3738/// RNN backward — weights path (gradients w.r.t. the weight space).
3739/// `add_grad = true` accumulates into `dweight_space` (typical for
3740/// multi-step training); `false` overwrites.
3741///
3742/// # Safety
3743/// Same as [`rnn_forward`].
3744#[allow(clippy::too_many_arguments)]
3745pub unsafe fn rnn_backward_weights_v8(
3746    handle: &Handle,
3747    rnn: &RnnDescriptor,
3748    add_grad: bool,
3749    dev_seq_lengths: *const i32,
3750    x_desc: &RnnDataDescriptor, x: *const core::ffi::c_void,
3751    h_desc: &TensorDescriptor,  hx: *const core::ffi::c_void,
3752    y_desc: &RnnDataDescriptor, y: *const core::ffi::c_void,
3753    dweight_space: &mut DeviceBuffer<u8>,
3754    work_space: &mut DeviceBuffer<u8>,
3755    reserve_space: &mut DeviceBuffer<u8>,
3756) -> Result<()> { unsafe {
3757    let c = cudnn()?;
3758    let f = c.cudnn_rnn_backward_weights_v8()?;
3759    check(f(
3760        handle.handle, rnn.desc, add_grad as core::ffi::c_int, dev_seq_lengths,
3761        x_desc.desc, x,
3762        h_desc.desc, hx,
3763        y_desc.desc, y,
3764        dweight_space.byte_size(), dweight_space.as_raw().0 as *mut core::ffi::c_void,
3765        work_space.byte_size(),    work_space.as_raw().0 as *mut core::ffi::c_void,
3766        reserve_space.byte_size(), reserve_space.as_raw().0 as *mut core::ffi::c_void,
3767    ))
3768}}