Skip to main content

oxicuda_dnn/
types.rs

1//! Core DNN type definitions.
2//!
3//! Provides tensor descriptors ([`TensorDesc`], [`TensorDescMut`]),
4//! layout conventions ([`TensorLayout`]), activation functions
5//! ([`Activation`]), convolution parameters ([`ConvolutionDescriptor`]),
6//! and algorithm selection ([`ConvAlgorithm`]).
7
8use std::marker::PhantomData;
9
10use oxicuda_blas::GpuFloat;
11use oxicuda_driver::ffi::CUdeviceptr;
12use oxicuda_memory::DeviceBuffer;
13
14use crate::error::{DnnError, DnnResult};
15
16// ---------------------------------------------------------------------------
17// TensorLayout
18// ---------------------------------------------------------------------------
19
/// How the logical axes of a tensor are ordered in linear device memory.
///
/// Channels-last layouts (NHWC / NDHWC) tend to be the better fit for Tensor
/// Cores on recent NVIDIA GPUs, whereas channels-first NCHW is what PyTorch
/// uses by default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorLayout {
    /// `[N, C, H, W]` -- channels-first 2-D images (PyTorch default).
    Nchw,
    /// `[N, H, W, C]` -- channels-last 2-D images (Tensor Core friendly).
    Nhwc,
    /// `[N, C, D, H, W]` -- channels-first 3-D volumes.
    Ncdhw,
    /// `[N, D, H, W, C]` -- channels-last 3-D volumes.
    Ndhwc,
    /// Plain row-major 2-D data such as matrices and MoE intermediates.
    RowMajor,
}

impl TensorLayout {
    /// Number of spatial axes (H/W, or D/H/W) the layout carries.
    ///
    /// `RowMajor` has no spatial axes and reports `0`.
    #[inline]
    #[must_use]
    pub const fn spatial_dims(self) -> usize {
        match self {
            Self::RowMajor => 0,
            Self::Nchw | Self::Nhwc => 2,
            Self::Ncdhw | Self::Ndhwc => 3,
        }
    }

    /// Total rank a tensor in this layout is expected to have.
    ///
    /// Image and volume layouts add the batch (N) and channel (C) axes to
    /// their spatial axes; `RowMajor` is a 2-D matrix.
    #[inline]
    #[must_use]
    pub const fn expected_ndim(self) -> usize {
        match self {
            Self::RowMajor => 2,
            Self::Nchw | Self::Nhwc | Self::Ncdhw | Self::Ndhwc => self.spatial_dims() + 2,
        }
    }

    /// `true` when the channel axis is the innermost one (NHWC / NDHWC).
    #[inline]
    #[must_use]
    pub const fn is_channels_last(self) -> bool {
        match self {
            Self::Nhwc | Self::Ndhwc => true,
            Self::Nchw | Self::Ncdhw | Self::RowMajor => false,
        }
    }
}
70
71// ---------------------------------------------------------------------------
72// Activation
73// ---------------------------------------------------------------------------
74
/// Element-wise activation functions understood by the DNN kernels.
///
/// These cover the activations most commonly found in deep-learning models.
/// Fusing the activation into the producing kernel (for example
/// conv + bias + ReLU) avoids an extra round-trip through device memory,
/// which makes it an important optimisation target.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Activation {
    /// ReLU: `max(0, x)`.
    Relu,
    /// Exact GELU: `x * Phi(x)`, with `Phi` the standard normal CDF.
    Gelu,
    /// GELU computed with the tanh-based approximation.
    GeluTanh,
    /// SiLU, also known as Swish: `x * sigmoid(x)`.
    Silu,
    /// Logistic sigmoid: `1 / (1 + exp(-x))`.
    Sigmoid,
    /// Hyperbolic tangent, `tanh(x)`.
    Tanh,
    /// Pass-through: the input is left unchanged.
    None,
}
97
98// ---------------------------------------------------------------------------
99// TensorDesc (immutable)
100// ---------------------------------------------------------------------------
101
102/// Immutable tensor descriptor binding a device pointer to shape metadata.
103///
104/// `TensorDesc` does **not** own the device memory; it merely borrows the
105/// raw pointer for the duration of a DNN operation.  The caller must ensure
106/// that the referenced [`DeviceBuffer`] outlives any computation that uses
107/// this descriptor.
108pub struct TensorDesc<T: GpuFloat> {
109    /// Raw device pointer to the first element.
110    pub ptr: CUdeviceptr,
111    /// Shape (one entry per dimension).
112    pub dims: Vec<u32>,
113    /// Strides (one entry per dimension, in **elements** not bytes).
114    pub strides: Vec<u32>,
115    /// Memory layout convention.
116    pub layout: TensorLayout,
117    _phantom: PhantomData<T>,
118}
119
120impl<T: GpuFloat> TensorDesc<T> {
121    /// Creates an NCHW tensor descriptor from a device buffer.
122    ///
123    /// # Errors
124    ///
125    /// Returns [`DnnError::InvalidDimension`] if any dimension is zero.
126    /// Returns [`DnnError::BufferTooSmall`] if the buffer cannot hold
127    /// `n * c * h * w` elements.
128    pub fn nchw(buf: &DeviceBuffer<T>, n: u32, c: u32, h: u32, w: u32) -> DnnResult<Self> {
129        Self::validate_dims(&[n, c, h, w])?;
130        let dims = vec![n, c, h, w];
131        let strides = nchw_strides(c, h, w);
132        let desc = Self {
133            ptr: buf.as_device_ptr(),
134            dims,
135            strides,
136            layout: TensorLayout::Nchw,
137            _phantom: PhantomData,
138        };
139        desc.validate_buffer_size(buf)?;
140        Ok(desc)
141    }
142
143    /// Creates an NHWC tensor descriptor from a device buffer.
144    ///
145    /// # Errors
146    ///
147    /// Same as [`nchw`](Self::nchw).
148    pub fn nhwc(buf: &DeviceBuffer<T>, n: u32, c: u32, h: u32, w: u32) -> DnnResult<Self> {
149        Self::validate_dims(&[n, c, h, w])?;
150        let dims = vec![n, c, h, w];
151        let strides = nhwc_strides(c, h, w);
152        let desc = Self {
153            ptr: buf.as_device_ptr(),
154            dims,
155            strides,
156            layout: TensorLayout::Nhwc,
157            _phantom: PhantomData,
158        };
159        desc.validate_buffer_size(buf)?;
160        Ok(desc)
161    }
162
163    /// Creates an NCDHW (3-D volumetric) tensor descriptor.
164    ///
165    /// # Errors
166    ///
167    /// Same as [`nchw`](Self::nchw).
168    pub fn ncdhw(buf: &DeviceBuffer<T>, n: u32, c: u32, d: u32, h: u32, w: u32) -> DnnResult<Self> {
169        Self::validate_dims(&[n, c, d, h, w])?;
170        let dims = vec![n, c, d, h, w];
171        let strides = vec![c * d * h * w, d * h * w, h * w, w, 1];
172        let desc = Self {
173            ptr: buf.as_device_ptr(),
174            dims,
175            strides,
176            layout: TensorLayout::Ncdhw,
177            _phantom: PhantomData,
178        };
179        desc.validate_buffer_size(buf)?;
180        Ok(desc)
181    }
182
183    /// Creates a 2-D matrix descriptor (rows x cols, row-major).
184    ///
185    /// # Errors
186    ///
187    /// Returns [`DnnError::InvalidDimension`] if either dimension is zero.
188    /// Returns [`DnnError::BufferTooSmall`] if the buffer is too small.
189    pub fn matrix(buf: &DeviceBuffer<T>, rows: u32, cols: u32) -> DnnResult<Self> {
190        Self::validate_dims(&[rows, cols])?;
191        let dims = vec![rows, cols];
192        let strides = vec![cols, 1];
193        let desc = Self {
194            ptr: buf.as_device_ptr(),
195            dims,
196            strides,
197            layout: TensorLayout::Nchw, // row-major, analogous to NCHW
198            _phantom: PhantomData,
199        };
200        desc.validate_buffer_size(buf)?;
201        Ok(desc)
202    }
203
204    /// Constructs a descriptor from raw components without buffer validation.
205    ///
206    /// The caller must ensure that `ptr` points to a valid device allocation
207    /// large enough for the described tensor.
208    pub fn from_raw(
209        ptr: CUdeviceptr,
210        dims: Vec<u32>,
211        strides: Vec<u32>,
212        layout: TensorLayout,
213    ) -> DnnResult<Self> {
214        if dims.len() != strides.len() {
215            return Err(DnnError::InvalidDimension(format!(
216                "dims length ({}) != strides length ({})",
217                dims.len(),
218                strides.len()
219            )));
220        }
221        if dims.is_empty() {
222            return Err(DnnError::InvalidDimension("empty dims".into()));
223        }
224        Ok(Self {
225            ptr,
226            dims,
227            strides,
228            layout,
229            _phantom: PhantomData,
230        })
231    }
232
233    /// Returns the total number of elements in the tensor.
234    #[inline]
235    #[must_use]
236    pub fn numel(&self) -> usize {
237        self.dims.iter().map(|&d| d as usize).product()
238    }
239
240    /// Returns the number of dimensions.
241    #[inline]
242    #[must_use]
243    pub fn ndim(&self) -> usize {
244        self.dims.len()
245    }
246
247    /// Validates that `buf` is large enough to hold this tensor.
248    ///
249    /// # Errors
250    ///
251    /// Returns [`DnnError::BufferTooSmall`] if the buffer has fewer elements
252    /// than [`numel`](Self::numel).
253    pub fn validate_buffer_size(&self, buf: &DeviceBuffer<T>) -> DnnResult<()> {
254        let required = self.numel() * T::SIZE;
255        let actual = buf.len() * T::SIZE;
256        if actual < required {
257            return Err(DnnError::BufferTooSmall {
258                expected: required,
259                actual,
260            });
261        }
262        Ok(())
263    }
264
265    /// Checks that no dimension is zero.
266    fn validate_dims(dims: &[u32]) -> DnnResult<()> {
267        for (i, &d) in dims.iter().enumerate() {
268            if d == 0 {
269                return Err(DnnError::InvalidDimension(format!("dimension {i} is zero")));
270            }
271        }
272        Ok(())
273    }
274}
275
276// ---------------------------------------------------------------------------
277// TensorDescMut (mutable output)
278// ---------------------------------------------------------------------------
279
280/// Mutable tensor descriptor for output buffers.
281///
282/// Identical to [`TensorDesc`] but signals that the referenced memory will
283/// be written to.  Having a separate type prevents accidentally aliasing an
284/// input and output tensor at the type level.
285pub struct TensorDescMut<T: GpuFloat> {
286    /// Raw device pointer to the first element (will be written).
287    pub ptr: CUdeviceptr,
288    /// Shape (one entry per dimension).
289    pub dims: Vec<u32>,
290    /// Strides (one entry per dimension, in elements).
291    pub strides: Vec<u32>,
292    /// Memory layout convention.
293    pub layout: TensorLayout,
294    _phantom: PhantomData<T>,
295}
296
297impl<T: GpuFloat> TensorDescMut<T> {
298    /// Creates a mutable NCHW tensor descriptor from a device buffer.
299    ///
300    /// # Errors
301    ///
302    /// Same validation as [`TensorDesc::nchw`].
303    pub fn nchw(buf: &mut DeviceBuffer<T>, n: u32, c: u32, h: u32, w: u32) -> DnnResult<Self> {
304        validate_dims_helper(&[n, c, h, w])?;
305        let numel = (n as usize) * (c as usize) * (h as usize) * (w as usize);
306        validate_buf_size::<T>(buf.len(), numel)?;
307        Ok(Self {
308            ptr: buf.as_device_ptr(),
309            dims: vec![n, c, h, w],
310            strides: nchw_strides(c, h, w),
311            layout: TensorLayout::Nchw,
312            _phantom: PhantomData,
313        })
314    }
315
316    /// Creates a mutable NHWC tensor descriptor from a device buffer.
317    ///
318    /// # Errors
319    ///
320    /// Same validation as [`TensorDesc::nhwc`].
321    pub fn nhwc(buf: &mut DeviceBuffer<T>, n: u32, c: u32, h: u32, w: u32) -> DnnResult<Self> {
322        validate_dims_helper(&[n, c, h, w])?;
323        let numel = (n as usize) * (c as usize) * (h as usize) * (w as usize);
324        validate_buf_size::<T>(buf.len(), numel)?;
325        Ok(Self {
326            ptr: buf.as_device_ptr(),
327            dims: vec![n, c, h, w],
328            strides: nhwc_strides(c, h, w),
329            layout: TensorLayout::Nhwc,
330            _phantom: PhantomData,
331        })
332    }
333
334    /// Creates a mutable 2-D matrix descriptor (rows x cols, row-major).
335    ///
336    /// # Errors
337    ///
338    /// Same validation as [`TensorDesc::matrix`].
339    pub fn matrix(buf: &mut DeviceBuffer<T>, rows: u32, cols: u32) -> DnnResult<Self> {
340        validate_dims_helper(&[rows, cols])?;
341        let numel = (rows as usize) * (cols as usize);
342        validate_buf_size::<T>(buf.len(), numel)?;
343        Ok(Self {
344            ptr: buf.as_device_ptr(),
345            dims: vec![rows, cols],
346            strides: vec![cols, 1],
347            layout: TensorLayout::Nchw,
348            _phantom: PhantomData,
349        })
350    }
351
352    /// Constructs a mutable descriptor from raw components.
353    pub fn from_raw(
354        ptr: CUdeviceptr,
355        dims: Vec<u32>,
356        strides: Vec<u32>,
357        layout: TensorLayout,
358    ) -> DnnResult<Self> {
359        if dims.len() != strides.len() {
360            return Err(DnnError::InvalidDimension(format!(
361                "dims length ({}) != strides length ({})",
362                dims.len(),
363                strides.len()
364            )));
365        }
366        if dims.is_empty() {
367            return Err(DnnError::InvalidDimension("empty dims".into()));
368        }
369        Ok(Self {
370            ptr,
371            dims,
372            strides,
373            layout,
374            _phantom: PhantomData,
375        })
376    }
377
378    /// Returns the total number of elements in the tensor.
379    #[inline]
380    #[must_use]
381    pub fn numel(&self) -> usize {
382        self.dims.iter().map(|&d| d as usize).product()
383    }
384
385    /// Returns the number of dimensions.
386    #[inline]
387    #[must_use]
388    pub fn ndim(&self) -> usize {
389        self.dims.len()
390    }
391
392    /// Borrows this mutable descriptor as an immutable [`TensorDesc`].
393    #[must_use]
394    pub fn as_immutable(&self) -> TensorDesc<T> {
395        TensorDesc {
396            ptr: self.ptr,
397            dims: self.dims.clone(),
398            strides: self.strides.clone(),
399            layout: self.layout,
400            _phantom: PhantomData,
401        }
402    }
403}
404
405// ---------------------------------------------------------------------------
406// ConvolutionDescriptor
407// ---------------------------------------------------------------------------
408
409/// Describes a convolution operation's hyper-parameters.
410///
411/// All vectors are indexed by spatial dimension (e.g. for 2-D convolutions
412/// they have length 2, for 3-D length 3).
413#[derive(Debug, Clone)]
414pub struct ConvolutionDescriptor {
415    /// Zero-padding applied to each spatial dimension (symmetric).
416    pub padding: Vec<u32>,
417    /// Stride of the convolution kernel in each spatial dimension.
418    pub stride: Vec<u32>,
419    /// Dilation factor in each spatial dimension.
420    pub dilation: Vec<u32>,
421    /// Number of groups for grouped/depthwise convolution.
422    pub groups: u32,
423}
424
425impl ConvolutionDescriptor {
426    /// Creates a standard 2-D convolution descriptor.
427    ///
428    /// # Errors
429    ///
430    /// Returns [`DnnError::InvalidArgument`] if stride or dilation contains
431    /// a zero value, or if groups is zero.
432    pub fn conv2d(
433        pad_h: u32,
434        pad_w: u32,
435        stride_h: u32,
436        stride_w: u32,
437        dilation_h: u32,
438        dilation_w: u32,
439        groups: u32,
440    ) -> DnnResult<Self> {
441        if stride_h == 0 || stride_w == 0 {
442            return Err(DnnError::InvalidArgument("stride must be non-zero".into()));
443        }
444        if dilation_h == 0 || dilation_w == 0 {
445            return Err(DnnError::InvalidArgument(
446                "dilation must be non-zero".into(),
447            ));
448        }
449        if groups == 0 {
450            return Err(DnnError::InvalidArgument("groups must be non-zero".into()));
451        }
452        Ok(Self {
453            padding: vec![pad_h, pad_w],
454            stride: vec![stride_h, stride_w],
455            dilation: vec![dilation_h, dilation_w],
456            groups,
457        })
458    }
459
460    /// Returns the number of spatial dimensions this descriptor covers.
461    #[inline]
462    #[must_use]
463    pub fn spatial_dims(&self) -> usize {
464        self.padding.len()
465    }
466
467    /// Computes the output spatial size for a single dimension.
468    ///
469    /// Formula: `floor((input + 2*pad - dilation*(kernel-1) - 1) / stride) + 1`
470    ///
471    /// # Errors
472    ///
473    /// Returns [`DnnError::InvalidDimension`] if the computation underflows
474    /// (i.e. the kernel is too large for the padded input).
475    pub fn output_size(
476        input: u32,
477        kernel: u32,
478        pad: u32,
479        stride: u32,
480        dilation: u32,
481    ) -> DnnResult<u32> {
482        let effective_kernel = dilation
483            .checked_mul(kernel.saturating_sub(1))
484            .and_then(|v| v.checked_add(1))
485            .ok_or_else(|| DnnError::InvalidDimension("effective kernel size overflow".into()))?;
486        let padded_input = input
487            .checked_add(2 * pad)
488            .ok_or_else(|| DnnError::InvalidDimension("padded input overflow".into()))?;
489        if padded_input < effective_kernel {
490            return Err(DnnError::InvalidDimension(format!(
491                "padded input ({padded_input}) < effective kernel ({effective_kernel})"
492            )));
493        }
494        Ok((padded_input - effective_kernel) / stride + 1)
495    }
496}
497
498// ---------------------------------------------------------------------------
499// ConvAlgorithm
500// ---------------------------------------------------------------------------
501
502/// Convolution algorithm selection.
503///
504/// Different algorithms offer different trade-offs between workspace memory
505/// and compute throughput.  The optimal choice depends on tensor sizes, GPU
506/// architecture, and available workspace.
507#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
508pub enum ConvAlgorithm {
509    /// Implicit GEMM -- low workspace, moderate speed.
510    ImplicitGemm,
511    /// Im2col followed by explicit GEMM -- higher workspace, often fastest
512    /// for medium-sized feature maps.
513    Im2colGemm,
514    /// Winograd transform -- fastest for 3x3 kernels with stride 1, but
515    /// requires workspace and may reduce numerical precision.
516    Winograd,
517    /// Direct convolution -- no workspace, straightforward nested loops.
518    Direct,
519    /// FFT-based convolution -- fastest for very large kernels.
520    FftConv,
521}
522
523// ---------------------------------------------------------------------------
524// TileConfig
525// ---------------------------------------------------------------------------
526
527/// Tile configuration for tiled convolution kernels.
528///
529/// Controls work decomposition across thread blocks and warps.
530#[derive(Debug, Clone, Copy)]
531pub struct TileConfig {
532    /// Tile size in the M dimension (output spatial points per block).
533    pub tile_m: u32,
534    /// Tile size in the N dimension (output channels per block).
535    pub tile_n: u32,
536    /// Tile size in the K dimension (reduction loop step).
537    pub tile_k: u32,
538    /// Warp-level tile in M.
539    pub warp_m: u32,
540    /// Warp-level tile in N.
541    pub warp_n: u32,
542    /// Number of software pipeline stages.
543    pub stages: u32,
544}
545
546impl TileConfig {
547    /// Returns a default tile configuration for the given SM version.
548    #[must_use]
549    pub fn default_conv(sm: oxicuda_ptx::arch::SmVersion) -> Self {
550        use oxicuda_ptx::arch::SmVersion;
551        match sm {
552            SmVersion::Sm90 | SmVersion::Sm90a | SmVersion::Sm100 | SmVersion::Sm120 => Self {
553                tile_m: 128,
554                tile_n: 128,
555                tile_k: 32,
556                warp_m: 64,
557                warp_n: 64,
558                stages: 4,
559            },
560            SmVersion::Sm80 | SmVersion::Sm86 | SmVersion::Sm89 => Self {
561                tile_m: 128,
562                tile_n: 128,
563                tile_k: 32,
564                warp_m: 64,
565                warp_n: 64,
566                stages: 3,
567            },
568            SmVersion::Sm75 => Self {
569                tile_m: 64,
570                tile_n: 64,
571                tile_k: 32,
572                warp_m: 32,
573                warp_n: 32,
574                stages: 2,
575            },
576        }
577    }
578}
579
580// ---------------------------------------------------------------------------
581// Helper: pool / resize output size
582// ---------------------------------------------------------------------------
583
584/// Computes the output spatial dimension for a pooling operation.
585///
586/// `output_dim = floor((input_dim + 2 * padding - kernel_size) / stride) + 1`
587///
588/// Returns `None` if the resulting dimension would be zero or negative.
589#[must_use]
590pub fn pool_output_size(
591    input_dim: u32,
592    kernel_size: u32,
593    stride: u32,
594    padding: u32,
595) -> Option<u32> {
596    if stride == 0 || kernel_size == 0 {
597        return None;
598    }
599    let effective = input_dim + 2 * padding;
600    if effective < kernel_size {
601        return None;
602    }
603    Some((effective - kernel_size) / stride + 1)
604}
605
606// ---------------------------------------------------------------------------
607// Stride helpers (private)
608// ---------------------------------------------------------------------------
609
610/// Computes NCHW strides: `[C*H*W, H*W, W, 1]`.
611fn nchw_strides(c: u32, h: u32, w: u32) -> Vec<u32> {
612    vec![c * h * w, h * w, w, 1]
613}
614
615/// Computes NHWC strides: `[H*W*C, 1, W*C, C]`.
616fn nhwc_strides(c: u32, h: u32, w: u32) -> Vec<u32> {
617    vec![h * w * c, 1, w * c, c]
618}
619
620/// Shared dimension validation.
621fn validate_dims_helper(dims: &[u32]) -> DnnResult<()> {
622    for (i, &d) in dims.iter().enumerate() {
623        if d == 0 {
624            return Err(DnnError::InvalidDimension(format!("dimension {i} is zero")));
625        }
626    }
627    Ok(())
628}
629
630/// Validates buffer size against required element count.
631fn validate_buf_size<T: GpuFloat>(buf_len: usize, required_numel: usize) -> DnnResult<()> {
632    let required = required_numel * T::SIZE;
633    let actual = buf_len * T::SIZE;
634    if actual < required {
635        return Err(DnnError::BufferTooSmall {
636            expected: required,
637            actual,
638        });
639    }
640    Ok(())
641}
642
643// ---------------------------------------------------------------------------
644// Tests
645// ---------------------------------------------------------------------------
646
647#[cfg(test)]
648mod tests {
649    use super::*;
650
651    #[test]
652    fn nchw_stride_order() {
653        let s = nchw_strides(3, 4, 5);
654        assert_eq!(s, vec![60, 20, 5, 1]);
655    }
656
657    #[test]
658    fn nhwc_stride_order() {
659        let s = nhwc_strides(3, 4, 5);
660        // N-stride = H*W*C = 60, C-stride = 1, H-stride = W*C = 15, W-stride = C = 3
661        assert_eq!(s, vec![60, 1, 15, 3]);
662    }
663
664    #[test]
665    fn conv_output_size_basic() {
666        // 32x32 input, 3x3 kernel, pad=1, stride=1, dilation=1 => 32
667        let out = ConvolutionDescriptor::output_size(32, 3, 1, 1, 1);
668        assert_eq!(out.ok(), Some(32));
669    }
670
671    #[test]
672    fn conv_output_size_strided() {
673        // 32x32 input, 3x3 kernel, pad=1, stride=2 => 16
674        let out = ConvolutionDescriptor::output_size(32, 3, 1, 2, 1);
675        assert_eq!(out.ok(), Some(16));
676    }
677
678    #[test]
679    fn conv_output_size_dilated() {
680        // 32x32 input, 3x3 kernel, pad=2, stride=1, dilation=2 => 32
681        let out = ConvolutionDescriptor::output_size(32, 3, 2, 1, 2);
682        assert_eq!(out.ok(), Some(32));
683    }
684
685    #[test]
686    fn conv_output_size_too_small() {
687        let out = ConvolutionDescriptor::output_size(3, 5, 0, 1, 1);
688        assert!(out.is_err());
689    }
690
691    #[test]
692    fn conv2d_zero_stride_rejected() {
693        let r = ConvolutionDescriptor::conv2d(0, 0, 0, 1, 1, 1, 1);
694        assert!(r.is_err());
695    }
696
697    #[test]
698    fn conv2d_zero_groups_rejected() {
699        let r = ConvolutionDescriptor::conv2d(0, 0, 1, 1, 1, 1, 0);
700        assert!(r.is_err());
701    }
702
703    #[test]
704    fn tensor_layout_spatial_dims() {
705        assert_eq!(TensorLayout::Nchw.spatial_dims(), 2);
706        assert_eq!(TensorLayout::Nhwc.spatial_dims(), 2);
707        assert_eq!(TensorLayout::Ncdhw.spatial_dims(), 3);
708        assert_eq!(TensorLayout::Ndhwc.spatial_dims(), 3);
709    }
710
711    #[test]
712    fn tensor_layout_expected_ndim() {
713        assert_eq!(TensorLayout::Nchw.expected_ndim(), 4);
714        assert_eq!(TensorLayout::Ncdhw.expected_ndim(), 5);
715    }
716
717    #[test]
718    fn from_raw_mismatched_lengths() {
719        let r = TensorDesc::<f32>::from_raw(0, vec![1, 2], vec![1], TensorLayout::Nchw);
720        assert!(r.is_err());
721    }
722
723    #[test]
724    fn from_raw_empty_dims() {
725        let r = TensorDesc::<f32>::from_raw(0, vec![], vec![], TensorLayout::Nchw);
726        assert!(r.is_err());
727    }
728
729    #[test]
730    fn activation_variants_are_distinct() {
731        assert_ne!(Activation::Relu, Activation::Gelu);
732        assert_ne!(Activation::Gelu, Activation::GeluTanh);
733        assert_ne!(Activation::Silu, Activation::Sigmoid);
734        assert_eq!(Activation::None, Activation::None);
735    }
736
737    #[test]
738    fn conv_algorithm_debug() {
739        let _ = format!("{:?}", ConvAlgorithm::Winograd);
740    }
741
742    #[test]
743    fn pool_output_basic() {
744        assert_eq!(pool_output_size(4, 2, 2, 0), Some(2));
745        assert_eq!(pool_output_size(5, 3, 1, 1), Some(5));
746    }
747
748    #[test]
749    fn pool_output_zero_stride() {
750        assert_eq!(pool_output_size(4, 2, 0, 0), None);
751    }
752
753    #[test]
754    fn pool_output_kernel_too_large() {
755        assert_eq!(pool_output_size(2, 5, 1, 0), None);
756    }
757}