use std::marker::PhantomData;
use oxicuda_blas::GpuFloat;
use oxicuda_driver::ffi::CUdeviceptr;
use oxicuda_memory::DeviceBuffer;
use crate::error::{DnnError, DnnResult};
/// Layout tag carried by tensor descriptors.
///
/// The variant names spell the axis order using `N` (batch), `C` (channels)
/// and `D`/`H`/`W` (spatial axes). The helper methods below are pure lookup
/// tables keyed on the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorLayout {
    /// 4-D layout, two spatial axes, not channels-last.
    Nchw,
    /// 4-D layout, two spatial axes, channels-last.
    Nhwc,
    /// 5-D layout, three spatial axes, not channels-last.
    Ncdhw,
    /// 5-D layout, three spatial axes, channels-last.
    Ndhwc,
    /// 2-D layout with no spatial axes.
    RowMajor,
}

impl TensorLayout {
    /// Returns how many spatial axes the layout has (0, 2 or 3).
    #[inline]
    #[must_use]
    pub const fn spatial_dims(self) -> usize {
        match self {
            Self::RowMajor => 0,
            Self::Nchw => 2,
            Self::Nhwc => 2,
            Self::Ncdhw => 3,
            Self::Ndhwc => 3,
        }
    }

    /// Returns the total number of axes a tensor with this layout is expected
    /// to have (2, 4 or 5).
    #[inline]
    #[must_use]
    pub const fn expected_ndim(self) -> usize {
        match self {
            Self::RowMajor => 2,
            Self::Nchw => 4,
            Self::Nhwc => 4,
            Self::Ncdhw => 5,
            Self::Ndhwc => 5,
        }
    }

    /// Returns `true` for the layouts whose name ends with the channel axis
    /// (`Nhwc`, `Ndhwc`) and `false` for every other variant.
    #[inline]
    #[must_use]
    pub const fn is_channels_last(self) -> bool {
        match self {
            Self::Nhwc | Self::Ndhwc => true,
            Self::Nchw | Self::Ncdhw | Self::RowMajor => false,
        }
    }
}
/// Selector for an element-wise activation function.
///
/// NOTE(review): this enum only names the activations; the exact formulas are
/// defined by whatever code consumes it, which is not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Activation {
/// ReLU.
Relu,
/// GELU. Presumably the exact (erf-based) form, given the separate
/// `GeluTanh` variant — TODO confirm against the kernels.
Gelu,
/// GELU, presumably the tanh approximation — TODO confirm against the kernels.
GeluTanh,
/// SiLU.
Silu,
/// Sigmoid.
Sigmoid,
/// Tanh.
Tanh,
/// No activation selected (presumably an identity pass-through — confirm).
None,
}
pub struct TensorDesc<T: GpuFloat> {
pub ptr: CUdeviceptr,
pub dims: Vec<u32>,
pub strides: Vec<u32>,
pub layout: TensorLayout,
_phantom: PhantomData<T>,
}
impl<T: GpuFloat> TensorDesc<T> {
pub fn nchw(buf: &DeviceBuffer<T>, n: u32, c: u32, h: u32, w: u32) -> DnnResult<Self> {
Self::validate_dims(&[n, c, h, w])?;
let dims = vec![n, c, h, w];
let strides = nchw_strides(c, h, w);
let desc = Self {
ptr: buf.as_device_ptr(),
dims,
strides,
layout: TensorLayout::Nchw,
_phantom: PhantomData,
};
desc.validate_buffer_size(buf)?;
Ok(desc)
}
pub fn nhwc(buf: &DeviceBuffer<T>, n: u32, c: u32, h: u32, w: u32) -> DnnResult<Self> {
Self::validate_dims(&[n, c, h, w])?;
let dims = vec![n, c, h, w];
let strides = nhwc_strides(c, h, w);
let desc = Self {
ptr: buf.as_device_ptr(),
dims,
strides,
layout: TensorLayout::Nhwc,
_phantom: PhantomData,
};
desc.validate_buffer_size(buf)?;
Ok(desc)
}
pub fn ncdhw(buf: &DeviceBuffer<T>, n: u32, c: u32, d: u32, h: u32, w: u32) -> DnnResult<Self> {
Self::validate_dims(&[n, c, d, h, w])?;
let dims = vec![n, c, d, h, w];
let strides = vec![c * d * h * w, d * h * w, h * w, w, 1];
let desc = Self {
ptr: buf.as_device_ptr(),
dims,
strides,
layout: TensorLayout::Ncdhw,
_phantom: PhantomData,
};
desc.validate_buffer_size(buf)?;
Ok(desc)
}
pub fn matrix(buf: &DeviceBuffer<T>, rows: u32, cols: u32) -> DnnResult<Self> {
Self::validate_dims(&[rows, cols])?;
let dims = vec![rows, cols];
let strides = vec![cols, 1];
let desc = Self {
ptr: buf.as_device_ptr(),
dims,
strides,
layout: TensorLayout::Nchw, _phantom: PhantomData,
};
desc.validate_buffer_size(buf)?;
Ok(desc)
}
pub fn from_raw(
ptr: CUdeviceptr,
dims: Vec<u32>,
strides: Vec<u32>,
layout: TensorLayout,
) -> DnnResult<Self> {
if dims.len() != strides.len() {
return Err(DnnError::InvalidDimension(format!(
"dims length ({}) != strides length ({})",
dims.len(),
strides.len()
)));
}
if dims.is_empty() {
return Err(DnnError::InvalidDimension("empty dims".into()));
}
Ok(Self {
ptr,
dims,
strides,
layout,
_phantom: PhantomData,
})
}
#[inline]
#[must_use]
pub fn numel(&self) -> usize {
self.dims.iter().map(|&d| d as usize).product()
}
#[inline]
#[must_use]
pub fn ndim(&self) -> usize {
self.dims.len()
}
pub fn validate_buffer_size(&self, buf: &DeviceBuffer<T>) -> DnnResult<()> {
let required = self.numel() * T::SIZE;
let actual = buf.len() * T::SIZE;
if actual < required {
return Err(DnnError::BufferTooSmall {
expected: required,
actual,
});
}
Ok(())
}
fn validate_dims(dims: &[u32]) -> DnnResult<()> {
for (i, &d) in dims.iter().enumerate() {
if d == 0 {
return Err(DnnError::InvalidDimension(format!("dimension {i} is zero")));
}
}
Ok(())
}
}
/// Descriptor for a tensor in device memory that is created from a mutable
/// buffer borrow.
///
/// Carries the same fields as [`TensorDesc`]. The struct has no lifetime
/// parameter, so the `&mut` borrow taken by the constructors ends when they
/// return; keeping the allocation alive is presumably the caller's job.
pub struct TensorDescMut<T: GpuFloat> {
    /// Device address taken from the backing buffer (or passed to `from_raw`).
    pub ptr: CUdeviceptr,
    /// Extent of each axis, in elements.
    pub dims: Vec<u32>,
    /// Stride of each axis, in elements.
    pub strides: Vec<u32>,
    /// Layout tag for this tensor.
    pub layout: TensorLayout,
    _phantom: PhantomData<T>,
}

impl<T: GpuFloat> TensorDescMut<T> {
    /// Describes `buf` as a densely packed 4-D tensor in NCHW order.
    ///
    /// # Errors
    ///
    /// * [`DnnError::InvalidDimension`] if any extent is zero or `c * h * w`
    ///   does not fit in `u32`.
    /// * Whatever `validate_buf_size` reports when `buf` holds fewer than
    ///   `n * c * h * w` elements.
    pub fn nchw(buf: &mut DeviceBuffer<T>, n: u32, c: u32, h: u32, w: u32) -> DnnResult<Self> {
        validate_dims_helper(&[n, c, h, w])?;
        Self::validate_stride_range(&[c, h, w])?;
        Self::packed(
            buf,
            vec![n, c, h, w],
            nchw_strides(c, h, w),
            TensorLayout::Nchw,
        )
    }

    /// Describes `buf` as a densely packed 4-D tensor in NHWC order.
    ///
    /// `dims` stays in `[n, c, h, w]` order; only the strides (from
    /// `nhwc_strides`) reflect the channels-last placement.
    ///
    /// # Errors
    ///
    /// Same conditions as [`TensorDescMut::nchw`].
    pub fn nhwc(buf: &mut DeviceBuffer<T>, n: u32, c: u32, h: u32, w: u32) -> DnnResult<Self> {
        validate_dims_helper(&[n, c, h, w])?;
        Self::validate_stride_range(&[c, h, w])?;
        Self::packed(
            buf,
            vec![n, c, h, w],
            nhwc_strides(c, h, w),
            TensorLayout::Nhwc,
        )
    }

    /// Describes `buf` as a row-major `rows x cols` matrix.
    ///
    /// The descriptor is tagged [`TensorLayout::RowMajor`], the only layout
    /// whose `expected_ndim()` is 2, so the tag agrees with `dims.len()`.
    ///
    /// # Errors
    ///
    /// * [`DnnError::InvalidDimension`] if `rows` or `cols` is zero.
    /// * Whatever `validate_buf_size` reports when `buf` holds fewer than
    ///   `rows * cols` elements.
    pub fn matrix(buf: &mut DeviceBuffer<T>, rows: u32, cols: u32) -> DnnResult<Self> {
        validate_dims_helper(&[rows, cols])?;
        Self::packed(
            buf,
            vec![rows, cols],
            vec![cols, 1],
            TensorLayout::RowMajor,
        )
    }

    /// Builds a descriptor from caller-supplied parts without touching a
    /// buffer.
    ///
    /// Only the shape of the metadata is checked; the pointer, the individual
    /// extents/strides and their agreement with `layout` are taken as given.
    ///
    /// # Errors
    ///
    /// [`DnnError::InvalidDimension`] if `dims` and `strides` differ in length
    /// or `dims` is empty.
    pub fn from_raw(
        ptr: CUdeviceptr,
        dims: Vec<u32>,
        strides: Vec<u32>,
        layout: TensorLayout,
    ) -> DnnResult<Self> {
        if dims.len() != strides.len() {
            return Err(DnnError::InvalidDimension(format!(
                "dims length ({}) != strides length ({})",
                dims.len(),
                strides.len()
            )));
        }
        if dims.is_empty() {
            return Err(DnnError::InvalidDimension("empty dims".into()));
        }
        Ok(Self {
            ptr,
            dims,
            strides,
            layout,
            _phantom: PhantomData,
        })
    }

    /// Total number of elements (product of all extents).
    #[inline]
    #[must_use]
    pub fn numel(&self) -> usize {
        self.dims.iter().map(|&d| d as usize).product()
    }

    /// Number of axes.
    #[inline]
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.dims.len()
    }

    /// Returns a [`TensorDesc`] with the same pointer, layout and cloned
    /// shape metadata.
    #[must_use]
    pub fn as_immutable(&self) -> TensorDesc<T> {
        TensorDesc {
            ptr: self.ptr,
            dims: self.dims.clone(),
            strides: self.strides.clone(),
            layout: self.layout,
            _phantom: PhantomData,
        }
    }

    /// Ensures the product of `inner` (all extents except the batch axis) fits
    /// in `u32`, which bounds the largest stride of a packed tensor.
    fn validate_stride_range(inner: &[u32]) -> DnnResult<()> {
        inner
            .iter()
            .try_fold(1u32, |acc, &d| acc.checked_mul(d))
            .map(|_| ())
            .ok_or_else(|| DnnError::InvalidDimension("tensor strides overflow u32".into()))
    }

    /// Verifies that `buf` can hold `dims` and assembles the descriptor.
    fn packed(
        buf: &mut DeviceBuffer<T>,
        dims: Vec<u32>,
        strides: Vec<u32>,
        layout: TensorLayout,
    ) -> DnnResult<Self> {
        let numel: usize = dims.iter().map(|&d| d as usize).product();
        validate_buf_size::<T>(buf.len(), numel)?;
        Ok(Self {
            ptr: buf.as_device_ptr(),
            dims,
            strides,
            layout,
            _phantom: PhantomData,
        })
    }
}
/// Hyper-parameters of a convolution, stored per spatial axis.
///
/// `padding`, `stride` and `dilation` hold one entry per spatial axis; for a
/// descriptor built by [`ConvolutionDescriptor::conv2d`] the order is
/// `[height, width]`.
#[derive(Debug, Clone)]
pub struct ConvolutionDescriptor {
    /// Padding per spatial axis.
    pub padding: Vec<u32>,
    /// Stride per spatial axis (non-zero when built via `conv2d`).
    pub stride: Vec<u32>,
    /// Dilation per spatial axis (non-zero when built via `conv2d`).
    pub dilation: Vec<u32>,
    /// Group count (non-zero when built via `conv2d`).
    pub groups: u32,
}

impl ConvolutionDescriptor {
    /// Builds a 2-D convolution descriptor.
    ///
    /// # Errors
    ///
    /// [`DnnError::InvalidArgument`] if any stride, any dilation, or `groups`
    /// is zero. Padding may be zero.
    pub fn conv2d(
        pad_h: u32,
        pad_w: u32,
        stride_h: u32,
        stride_w: u32,
        dilation_h: u32,
        dilation_w: u32,
        groups: u32,
    ) -> DnnResult<Self> {
        if stride_h == 0 || stride_w == 0 {
            return Err(DnnError::InvalidArgument("stride must be non-zero".into()));
        }
        if dilation_h == 0 || dilation_w == 0 {
            return Err(DnnError::InvalidArgument(
                "dilation must be non-zero".into(),
            ));
        }
        if groups == 0 {
            return Err(DnnError::InvalidArgument("groups must be non-zero".into()));
        }
        Ok(Self {
            padding: vec![pad_h, pad_w],
            stride: vec![stride_h, stride_w],
            dilation: vec![dilation_h, dilation_w],
            groups,
        })
    }

    /// Number of spatial axes, taken from the length of `padding`.
    #[inline]
    #[must_use]
    pub fn spatial_dims(&self) -> usize {
        self.padding.len()
    }

    /// Computes the output extent along one spatial axis:
    ///
    /// `(input + 2 * pad - (dilation * (kernel - 1) + 1)) / stride + 1`
    ///
    /// using integer (floor) division. A `kernel` of zero is treated like a
    /// kernel of one because `kernel - 1` saturates at zero.
    ///
    /// # Errors
    ///
    /// * [`DnnError::InvalidArgument`] if `stride` is zero.
    /// * [`DnnError::InvalidDimension`] if the effective kernel size or the
    ///   padded input size overflows `u32`, or if the padded input is smaller
    ///   than the effective kernel.
    pub fn output_size(
        input: u32,
        kernel: u32,
        pad: u32,
        stride: u32,
        dilation: u32,
    ) -> DnnResult<u32> {
        // A zero stride would otherwise reach the division below and panic.
        if stride == 0 {
            return Err(DnnError::InvalidArgument("stride must be non-zero".into()));
        }
        let effective_kernel = dilation
            .checked_mul(kernel.saturating_sub(1))
            .and_then(|v| v.checked_add(1))
            .ok_or_else(|| DnnError::InvalidDimension("effective kernel size overflow".into()))?;
        // Doubling the padding is part of the overflow-checked chain too.
        let padded_input = pad
            .checked_mul(2)
            .and_then(|both_sides| input.checked_add(both_sides))
            .ok_or_else(|| DnnError::InvalidDimension("padded input overflow".into()))?;
        if padded_input < effective_kernel {
            return Err(DnnError::InvalidDimension(format!(
                "padded input ({padded_input}) < effective kernel ({effective_kernel})"
            )));
        }
        // `effective_kernel >= 1`, so the quotient is at most `u32::MAX - 1`
        // and the final `+ 1` cannot overflow.
        Ok((padded_input - effective_kernel) / stride + 1)
    }
}
/// Selector for a convolution implementation strategy.
///
/// NOTE(review): the per-variant notes below are read off the variant names;
/// the actual kernels are chosen elsewhere and are not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConvAlgorithm {
/// Presumably implicit-GEMM convolution — confirm.
ImplicitGemm,
/// Presumably im2col followed by a GEMM — confirm.
Im2colGemm,
/// Presumably a Winograd-based convolution — confirm.
Winograd,
/// Presumably a direct convolution — confirm.
Direct,
/// Presumably an FFT-based convolution — confirm.
FftConv,
}
/// Tiling parameters for a convolution kernel.
///
/// NOTE(review): the field meanings below follow the field names (block tile,
/// warp tile, pipeline stages); confirm against the kernel generator.
#[derive(Debug, Clone, Copy)]
pub struct TileConfig {
    /// Tile extent along M.
    pub tile_m: u32,
    /// Tile extent along N.
    pub tile_n: u32,
    /// Tile extent along K.
    pub tile_k: u32,
    /// Warp tile extent along M.
    pub warp_m: u32,
    /// Warp tile extent along N.
    pub warp_n: u32,
    /// Stage count.
    pub stages: u32,
}

impl TileConfig {
    /// Returns the built-in convolution preset for the given SM version.
    ///
    /// | SM versions              | tile (M x N x K) | warp (M x N) | stages |
    /// |--------------------------|------------------|--------------|--------|
    /// | 90, 90a, 100, 120        | 128 x 128 x 32   | 64 x 64      | 4      |
    /// | 80, 86, 89               | 128 x 128 x 32   | 64 x 64      | 3      |
    /// | 75                       | 64 x 64 x 32     | 32 x 32      | 2      |
    #[must_use]
    pub fn default_conv(sm: oxicuda_ptx::arch::SmVersion) -> Self {
        use oxicuda_ptx::arch::SmVersion;

        // Every preset uses square block and warp tiles with `tile_k == 32`,
        // so only the two edge lengths and the stage count vary.
        let (block_edge, warp_edge, stages) = match sm {
            SmVersion::Sm90 | SmVersion::Sm90a | SmVersion::Sm100 | SmVersion::Sm120 => {
                (128, 64, 4)
            }
            SmVersion::Sm80 | SmVersion::Sm86 | SmVersion::Sm89 => (128, 64, 3),
            SmVersion::Sm75 => (64, 32, 2),
        };

        Self {
            tile_m: block_edge,
            tile_n: block_edge,
            tile_k: 32,
            warp_m: warp_edge,
            warp_n: warp_edge,
            stages,
        }
    }
}
/// Computes the output extent of a pooling window along one axis:
///
/// `(input_dim + 2 * padding - kernel_size) / stride + 1`
///
/// using integer (floor) division.
///
/// The arithmetic is carried out in `u64`, so large `input_dim`/`padding`
/// values cannot overflow on the way to a result that itself fits in `u32`.
///
/// Returns `None` when `stride` or `kernel_size` is zero, when the padded
/// input is smaller than the kernel, or when the result does not fit in
/// `u32`.
#[must_use]
pub fn pool_output_size(
    input_dim: u32,
    kernel_size: u32,
    stride: u32,
    padding: u32,
) -> Option<u32> {
    if stride == 0 || kernel_size == 0 {
        return None;
    }
    // At most (2^32 - 1) * 3, far below the u64 limit.
    let effective = u64::from(input_dim) + 2 * u64::from(padding);
    // `checked_sub` yields `None` exactly when the kernel exceeds the padded input.
    let span = effective.checked_sub(u64::from(kernel_size))?;
    let out = span / u64::from(stride) + 1;
    if out > u64::from(u32::MAX) {
        None
    } else {
        // Guarded by the range check above, so the narrowing is lossless.
        Some(out as u32)
    }
}
/// Element strides of a densely packed tensor stored in NCHW order, returned
/// as `[batch, channel, row, column]`.
fn nchw_strides(c: u32, h: u32, w: u32) -> Vec<u32> {
    let batch = c * h * w;
    let channel = h * w;
    let row = w;
    [batch, channel, row, 1].to_vec()
}
/// Element strides of a densely packed tensor stored in NHWC order.
///
/// The result is indexed in `[n, c, h, w]` order, so the channel entry is `1`
/// and the width entry is `c`.
fn nhwc_strides(c: u32, h: u32, w: u32) -> Vec<u32> {
    let batch = h * w * c;
    let row = w * c;
    let column = c;
    [batch, 1, row, column].to_vec()
}
/// Rejects shapes containing a zero extent.
///
/// Returns [`DnnError::InvalidDimension`] naming the index of the first zero
/// entry, or `Ok(())` when every entry is non-zero (including the empty
/// slice).
fn validate_dims_helper(dims: &[u32]) -> DnnResult<()> {
    match dims.iter().position(|&extent| extent == 0) {
        Some(i) => Err(DnnError::InvalidDimension(format!("dimension {i} is zero"))),
        None => Ok(()),
    }
}
/// Checks that a buffer of `buf_len` elements can hold `required_numel`
/// elements of type `T`.
///
/// The comparison is made on element counts, so it cannot be defeated by the
/// byte-size multiplication wrapping around. Assumes `T::SIZE` is non-zero.
///
/// # Errors
///
/// [`DnnError::BufferTooSmall`] with the required and available sizes in
/// bytes; each value saturates at `usize::MAX` instead of overflowing.
fn validate_buf_size<T: GpuFloat>(buf_len: usize, required_numel: usize) -> DnnResult<()> {
    if buf_len >= required_numel {
        return Ok(());
    }
    Err(DnnError::BufferTooSmall {
        expected: required_numel.saturating_mul(T::SIZE),
        actual: buf_len.saturating_mul(T::SIZE),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// NCHW strides for a 3x4x5 sample.
    #[test]
    fn nchw_stride_order() {
        let expected: Vec<u32> = vec![60, 20, 5, 1];
        assert_eq!(nchw_strides(3, 4, 5), expected);
    }

    /// NHWC strides for a 3x4x5 sample, indexed in `[n, c, h, w]` order.
    #[test]
    fn nhwc_stride_order() {
        let expected: Vec<u32> = vec![60, 1, 15, 3];
        assert_eq!(nhwc_strides(3, 4, 5), expected);
    }

    /// 3-wide kernel, padding 1, stride 1 keeps the extent.
    #[test]
    fn conv_output_size_basic() {
        let size = ConvolutionDescriptor::output_size(32, 3, 1, 1, 1).ok();
        assert_eq!(size, Some(32));
    }

    /// Stride 2 halves the extent.
    #[test]
    fn conv_output_size_strided() {
        let size = ConvolutionDescriptor::output_size(32, 3, 1, 2, 1).ok();
        assert_eq!(size, Some(16));
    }

    /// Dilation 2 with padding 2 keeps the extent.
    #[test]
    fn conv_output_size_dilated() {
        let size = ConvolutionDescriptor::output_size(32, 3, 2, 1, 2).ok();
        assert_eq!(size, Some(32));
    }

    /// A kernel larger than the padded input is an error.
    #[test]
    fn conv_output_size_too_small() {
        assert!(ConvolutionDescriptor::output_size(3, 5, 0, 1, 1).is_err());
    }

    /// `conv2d` refuses a zero stride.
    #[test]
    fn conv2d_zero_stride_rejected() {
        assert!(ConvolutionDescriptor::conv2d(0, 0, 0, 1, 1, 1, 1).is_err());
    }

    /// `conv2d` refuses a zero group count.
    #[test]
    fn conv2d_zero_groups_rejected() {
        assert!(ConvolutionDescriptor::conv2d(0, 0, 1, 1, 1, 1, 0).is_err());
    }

    /// Spatial axis counts of the 4-D and 5-D layouts.
    #[test]
    fn tensor_layout_spatial_dims() {
        let cases = [
            (TensorLayout::Nchw, 2),
            (TensorLayout::Nhwc, 2),
            (TensorLayout::Ncdhw, 3),
            (TensorLayout::Ndhwc, 3),
        ];
        for (layout, expected) in cases.iter().copied() {
            assert_eq!(layout.spatial_dims(), expected);
        }
    }

    /// Total axis counts of the channels-first layouts.
    #[test]
    fn tensor_layout_expected_ndim() {
        assert_eq!(TensorLayout::Nchw.expected_ndim(), 4);
        assert_eq!(TensorLayout::Ncdhw.expected_ndim(), 5);
    }

    /// `from_raw` rejects `dims`/`strides` of different lengths.
    #[test]
    fn from_raw_mismatched_lengths() {
        let built = TensorDesc::<f32>::from_raw(0, vec![1, 2], vec![1], TensorLayout::Nchw);
        assert!(built.is_err());
    }

    /// `from_raw` rejects an empty shape.
    #[test]
    fn from_raw_empty_dims() {
        let built = TensorDesc::<f32>::from_raw(0, vec![], vec![], TensorLayout::Nchw);
        assert!(built.is_err());
    }

    /// Equality on `Activation` distinguishes the variants.
    #[test]
    fn activation_variants_are_distinct() {
        assert_ne!(Activation::Relu, Activation::Gelu);
        assert_ne!(Activation::Gelu, Activation::GeluTanh);
        assert_ne!(Activation::Silu, Activation::Sigmoid);
        assert_eq!(Activation::None, Activation::None);
    }

    /// `ConvAlgorithm` can be formatted with `{:?}`.
    #[test]
    fn conv_algorithm_debug() {
        let _rendered = format!("{:?}", ConvAlgorithm::Winograd);
    }

    /// Regular pooling shapes.
    #[test]
    fn pool_output_basic() {
        assert_eq!(pool_output_size(4, 2, 2, 0), Some(2));
        assert_eq!(pool_output_size(5, 3, 1, 1), Some(5));
    }

    /// A zero stride yields `None`.
    #[test]
    fn pool_output_zero_stride() {
        assert!(pool_output_size(4, 2, 0, 0).is_none());
    }

    /// A kernel larger than the padded input yields `None`.
    #[test]
    fn pool_output_kernel_too_large() {
        assert!(pool_output_size(2, 5, 1, 0).is_none());
    }
}