1use std::marker::PhantomData;
9
10use oxicuda_blas::GpuFloat;
11use oxicuda_driver::ffi::CUdeviceptr;
12use oxicuda_memory::DeviceBuffer;
13
14use crate::error::{DnnError, DnnResult};
15
/// Physical memory layout of a tensor.
///
/// The 4-D variants (`Nchw`, `Nhwc`) carry two spatial dimensions, the 5-D
/// variants (`Ncdhw`, `Ndhwc`) carry three, and `RowMajor` describes a plain
/// 2-D matrix with no spatial dimensions.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TensorLayout {
    /// Batch, channels, height, width; width varies fastest.
    Nchw,
    /// Channels-last 4-D layout; channels vary fastest.
    Nhwc,
    /// Batch, channels, depth, height, width; width varies fastest.
    Ncdhw,
    /// Channels-last 5-D layout.
    Ndhwc,
    /// 2-D row-major matrix.
    RowMajor,
}
39
40impl TensorLayout {
41 #[inline]
43 #[must_use]
44 pub const fn spatial_dims(self) -> usize {
45 match self {
46 Self::Nchw | Self::Nhwc => 2,
47 Self::Ncdhw | Self::Ndhwc => 3,
48 Self::RowMajor => 0,
49 }
50 }
51
52 #[inline]
54 #[must_use]
55 pub const fn expected_ndim(self) -> usize {
56 match self {
57 Self::Nchw | Self::Nhwc => 4,
58 Self::Ncdhw | Self::Ndhwc => 5,
59 Self::RowMajor => 2,
60 }
61 }
62
63 #[inline]
65 #[must_use]
66 pub const fn is_channels_last(self) -> bool {
67 matches!(self, Self::Nhwc | Self::Ndhwc)
68 }
69}
70
/// Activation function selector.
///
/// How each variant is evaluated is decided by the kernels that consume this
/// enum, which live outside this file.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Activation {
    /// Rectified linear unit.
    Relu,
    /// GELU.
    Gelu,
    /// GELU variant named `GeluTanh` (presumably the tanh approximation — confirm in the kernel).
    GeluTanh,
    /// SiLU.
    Silu,
    /// Logistic sigmoid.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// No activation.
    None,
}
97
98pub struct TensorDesc<T: GpuFloat> {
109 pub ptr: CUdeviceptr,
111 pub dims: Vec<u32>,
113 pub strides: Vec<u32>,
115 pub layout: TensorLayout,
117 _phantom: PhantomData<T>,
118}
119
120impl<T: GpuFloat> TensorDesc<T> {
121 pub fn nchw(buf: &DeviceBuffer<T>, n: u32, c: u32, h: u32, w: u32) -> DnnResult<Self> {
129 Self::validate_dims(&[n, c, h, w])?;
130 let dims = vec![n, c, h, w];
131 let strides = nchw_strides(c, h, w);
132 let desc = Self {
133 ptr: buf.as_device_ptr(),
134 dims,
135 strides,
136 layout: TensorLayout::Nchw,
137 _phantom: PhantomData,
138 };
139 desc.validate_buffer_size(buf)?;
140 Ok(desc)
141 }
142
143 pub fn nhwc(buf: &DeviceBuffer<T>, n: u32, c: u32, h: u32, w: u32) -> DnnResult<Self> {
149 Self::validate_dims(&[n, c, h, w])?;
150 let dims = vec![n, c, h, w];
151 let strides = nhwc_strides(c, h, w);
152 let desc = Self {
153 ptr: buf.as_device_ptr(),
154 dims,
155 strides,
156 layout: TensorLayout::Nhwc,
157 _phantom: PhantomData,
158 };
159 desc.validate_buffer_size(buf)?;
160 Ok(desc)
161 }
162
163 pub fn ncdhw(buf: &DeviceBuffer<T>, n: u32, c: u32, d: u32, h: u32, w: u32) -> DnnResult<Self> {
169 Self::validate_dims(&[n, c, d, h, w])?;
170 let dims = vec![n, c, d, h, w];
171 let strides = vec![c * d * h * w, d * h * w, h * w, w, 1];
172 let desc = Self {
173 ptr: buf.as_device_ptr(),
174 dims,
175 strides,
176 layout: TensorLayout::Ncdhw,
177 _phantom: PhantomData,
178 };
179 desc.validate_buffer_size(buf)?;
180 Ok(desc)
181 }
182
183 pub fn matrix(buf: &DeviceBuffer<T>, rows: u32, cols: u32) -> DnnResult<Self> {
190 Self::validate_dims(&[rows, cols])?;
191 let dims = vec![rows, cols];
192 let strides = vec![cols, 1];
193 let desc = Self {
194 ptr: buf.as_device_ptr(),
195 dims,
196 strides,
197 layout: TensorLayout::Nchw, _phantom: PhantomData,
199 };
200 desc.validate_buffer_size(buf)?;
201 Ok(desc)
202 }
203
204 pub fn from_raw(
209 ptr: CUdeviceptr,
210 dims: Vec<u32>,
211 strides: Vec<u32>,
212 layout: TensorLayout,
213 ) -> DnnResult<Self> {
214 if dims.len() != strides.len() {
215 return Err(DnnError::InvalidDimension(format!(
216 "dims length ({}) != strides length ({})",
217 dims.len(),
218 strides.len()
219 )));
220 }
221 if dims.is_empty() {
222 return Err(DnnError::InvalidDimension("empty dims".into()));
223 }
224 Ok(Self {
225 ptr,
226 dims,
227 strides,
228 layout,
229 _phantom: PhantomData,
230 })
231 }
232
233 #[inline]
235 #[must_use]
236 pub fn numel(&self) -> usize {
237 self.dims.iter().map(|&d| d as usize).product()
238 }
239
240 #[inline]
242 #[must_use]
243 pub fn ndim(&self) -> usize {
244 self.dims.len()
245 }
246
247 pub fn validate_buffer_size(&self, buf: &DeviceBuffer<T>) -> DnnResult<()> {
254 let required = self.numel() * T::SIZE;
255 let actual = buf.len() * T::SIZE;
256 if actual < required {
257 return Err(DnnError::BufferTooSmall {
258 expected: required,
259 actual,
260 });
261 }
262 Ok(())
263 }
264
265 fn validate_dims(dims: &[u32]) -> DnnResult<()> {
267 for (i, &d) in dims.iter().enumerate() {
268 if d == 0 {
269 return Err(DnnError::InvalidDimension(format!("dimension {i} is zero")));
270 }
271 }
272 Ok(())
273 }
274}
275
276pub struct TensorDescMut<T: GpuFloat> {
286 pub ptr: CUdeviceptr,
288 pub dims: Vec<u32>,
290 pub strides: Vec<u32>,
292 pub layout: TensorLayout,
294 _phantom: PhantomData<T>,
295}
296
297impl<T: GpuFloat> TensorDescMut<T> {
298 pub fn nchw(buf: &mut DeviceBuffer<T>, n: u32, c: u32, h: u32, w: u32) -> DnnResult<Self> {
304 validate_dims_helper(&[n, c, h, w])?;
305 let numel = (n as usize) * (c as usize) * (h as usize) * (w as usize);
306 validate_buf_size::<T>(buf.len(), numel)?;
307 Ok(Self {
308 ptr: buf.as_device_ptr(),
309 dims: vec![n, c, h, w],
310 strides: nchw_strides(c, h, w),
311 layout: TensorLayout::Nchw,
312 _phantom: PhantomData,
313 })
314 }
315
316 pub fn nhwc(buf: &mut DeviceBuffer<T>, n: u32, c: u32, h: u32, w: u32) -> DnnResult<Self> {
322 validate_dims_helper(&[n, c, h, w])?;
323 let numel = (n as usize) * (c as usize) * (h as usize) * (w as usize);
324 validate_buf_size::<T>(buf.len(), numel)?;
325 Ok(Self {
326 ptr: buf.as_device_ptr(),
327 dims: vec![n, c, h, w],
328 strides: nhwc_strides(c, h, w),
329 layout: TensorLayout::Nhwc,
330 _phantom: PhantomData,
331 })
332 }
333
334 pub fn matrix(buf: &mut DeviceBuffer<T>, rows: u32, cols: u32) -> DnnResult<Self> {
340 validate_dims_helper(&[rows, cols])?;
341 let numel = (rows as usize) * (cols as usize);
342 validate_buf_size::<T>(buf.len(), numel)?;
343 Ok(Self {
344 ptr: buf.as_device_ptr(),
345 dims: vec![rows, cols],
346 strides: vec![cols, 1],
347 layout: TensorLayout::Nchw,
348 _phantom: PhantomData,
349 })
350 }
351
352 pub fn from_raw(
354 ptr: CUdeviceptr,
355 dims: Vec<u32>,
356 strides: Vec<u32>,
357 layout: TensorLayout,
358 ) -> DnnResult<Self> {
359 if dims.len() != strides.len() {
360 return Err(DnnError::InvalidDimension(format!(
361 "dims length ({}) != strides length ({})",
362 dims.len(),
363 strides.len()
364 )));
365 }
366 if dims.is_empty() {
367 return Err(DnnError::InvalidDimension("empty dims".into()));
368 }
369 Ok(Self {
370 ptr,
371 dims,
372 strides,
373 layout,
374 _phantom: PhantomData,
375 })
376 }
377
378 #[inline]
380 #[must_use]
381 pub fn numel(&self) -> usize {
382 self.dims.iter().map(|&d| d as usize).product()
383 }
384
385 #[inline]
387 #[must_use]
388 pub fn ndim(&self) -> usize {
389 self.dims.len()
390 }
391
392 #[must_use]
394 pub fn as_immutable(&self) -> TensorDesc<T> {
395 TensorDesc {
396 ptr: self.ptr,
397 dims: self.dims.clone(),
398 strides: self.strides.clone(),
399 layout: self.layout,
400 _phantom: PhantomData,
401 }
402 }
403}
404
/// Convolution hyper-parameters, stored per spatial dimension.
///
/// `padding`, `stride`, and `dilation` each hold one entry per spatial axis
/// (two entries when built with `conv2d`).
#[derive(Clone, Debug)]
pub struct ConvolutionDescriptor {
    /// Zero-padding applied on each side, per spatial axis.
    pub padding: Vec<u32>,
    /// Step between successive kernel applications, per spatial axis.
    pub stride: Vec<u32>,
    /// Spacing between kernel taps, per spatial axis.
    pub dilation: Vec<u32>,
    /// Number of channel groups.
    pub groups: u32,
}
424
425impl ConvolutionDescriptor {
426 pub fn conv2d(
433 pad_h: u32,
434 pad_w: u32,
435 stride_h: u32,
436 stride_w: u32,
437 dilation_h: u32,
438 dilation_w: u32,
439 groups: u32,
440 ) -> DnnResult<Self> {
441 if stride_h == 0 || stride_w == 0 {
442 return Err(DnnError::InvalidArgument("stride must be non-zero".into()));
443 }
444 if dilation_h == 0 || dilation_w == 0 {
445 return Err(DnnError::InvalidArgument(
446 "dilation must be non-zero".into(),
447 ));
448 }
449 if groups == 0 {
450 return Err(DnnError::InvalidArgument("groups must be non-zero".into()));
451 }
452 Ok(Self {
453 padding: vec![pad_h, pad_w],
454 stride: vec![stride_h, stride_w],
455 dilation: vec![dilation_h, dilation_w],
456 groups,
457 })
458 }
459
460 #[inline]
462 #[must_use]
463 pub fn spatial_dims(&self) -> usize {
464 self.padding.len()
465 }
466
467 pub fn output_size(
476 input: u32,
477 kernel: u32,
478 pad: u32,
479 stride: u32,
480 dilation: u32,
481 ) -> DnnResult<u32> {
482 let effective_kernel = dilation
483 .checked_mul(kernel.saturating_sub(1))
484 .and_then(|v| v.checked_add(1))
485 .ok_or_else(|| DnnError::InvalidDimension("effective kernel size overflow".into()))?;
486 let padded_input = input
487 .checked_add(2 * pad)
488 .ok_or_else(|| DnnError::InvalidDimension("padded input overflow".into()))?;
489 if padded_input < effective_kernel {
490 return Err(DnnError::InvalidDimension(format!(
491 "padded input ({padded_input}) < effective kernel ({effective_kernel})"
492 )));
493 }
494 Ok((padded_input - effective_kernel) / stride + 1)
495 }
496}
497
/// Convolution algorithm selector.
///
/// Which kernel implements each variant is decided outside this file.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ConvAlgorithm {
    /// Implicit-GEMM formulation.
    ImplicitGemm,
    /// Explicit im2col followed by GEMM.
    Im2colGemm,
    /// Winograd transform.
    Winograd,
    /// Direct convolution.
    Direct,
    /// FFT-based convolution.
    FftConv,
}
522
/// Tiling parameters for a convolution kernel launch.
///
/// How the fields are consumed is decided by kernel code outside this file.
#[derive(Clone, Copy, Debug)]
pub struct TileConfig {
    /// Block tile extent along M.
    pub tile_m: u32,
    /// Block tile extent along N.
    pub tile_n: u32,
    /// Block tile extent along K.
    pub tile_k: u32,
    /// Warp tile extent along M.
    pub warp_m: u32,
    /// Warp tile extent along N.
    pub warp_n: u32,
    /// Stage count (presumably pipeline depth — confirm in the kernel).
    pub stages: u32,
}
545
546impl TileConfig {
547 #[must_use]
549 pub fn default_conv(sm: oxicuda_ptx::arch::SmVersion) -> Self {
550 use oxicuda_ptx::arch::SmVersion;
551 match sm {
552 SmVersion::Sm90 | SmVersion::Sm90a | SmVersion::Sm100 | SmVersion::Sm120 => Self {
553 tile_m: 128,
554 tile_n: 128,
555 tile_k: 32,
556 warp_m: 64,
557 warp_n: 64,
558 stages: 4,
559 },
560 SmVersion::Sm80 | SmVersion::Sm86 | SmVersion::Sm89 => Self {
561 tile_m: 128,
562 tile_n: 128,
563 tile_k: 32,
564 warp_m: 64,
565 warp_n: 64,
566 stages: 3,
567 },
568 SmVersion::Sm75 => Self {
569 tile_m: 64,
570 tile_n: 64,
571 tile_k: 32,
572 warp_m: 32,
573 warp_n: 32,
574 stages: 2,
575 },
576 }
577 }
578}
579
/// Computes the output extent of one pooling axis:
/// `(input_dim + 2*padding - kernel_size) / stride + 1`.
///
/// Returns `None` when `stride` or `kernel_size` is zero, when the padded
/// input overflows `u32`, or when the kernel is larger than the padded input.
#[must_use]
pub fn pool_output_size(
    input_dim: u32,
    kernel_size: u32,
    stride: u32,
    padding: u32,
) -> Option<u32> {
    if stride == 0 || kernel_size == 0 {
        return None;
    }
    // Checked: `input_dim + 2 * padding` could previously overflow u32
    // (panic in debug builds, silent wrap in release builds).
    let padded = padding
        .checked_mul(2)
        .and_then(|both_sides| input_dim.checked_add(both_sides))?;
    let slack = padded.checked_sub(kernel_size)?;
    Some(slack / stride + 1)
}
605
/// Element strides for a contiguous NCHW tensor with dims `[n, c, h, w]`.
///
/// Width is innermost (stride 1), then height, then channel, then batch.
fn nchw_strides(c: u32, h: u32, w: u32) -> Vec<u32> {
    let per_image = c * h * w;
    let per_plane = h * w;
    let per_row = w;
    vec![per_image, per_plane, per_row, 1]
}
614
/// Element strides for a channels-last tensor whose dims are stored in
/// logical `[n, c, h, w]` order.
///
/// Channel is innermost (stride 1); a pixel spans `c` elements, a row spans
/// `w * c`, and an image spans `h * w * c`.
fn nhwc_strides(c: u32, h: u32, w: u32) -> Vec<u32> {
    let per_image = h * w * c;
    let per_row = w * c;
    let per_pixel = c;
    vec![per_image, 1, per_row, per_pixel]
}
619
620fn validate_dims_helper(dims: &[u32]) -> DnnResult<()> {
622 for (i, &d) in dims.iter().enumerate() {
623 if d == 0 {
624 return Err(DnnError::InvalidDimension(format!("dimension {i} is zero")));
625 }
626 }
627 Ok(())
628}
629
630fn validate_buf_size<T: GpuFloat>(buf_len: usize, required_numel: usize) -> DnnResult<()> {
632 let required = required_numel * T::SIZE;
633 let actual = buf_len * T::SIZE;
634 if actual < required {
635 return Err(DnnError::BufferTooSmall {
636 expected: required,
637 actual,
638 });
639 }
640 Ok(())
641}
642
#[cfg(test)]
mod tests {
    //! Unit tests for layout helpers, descriptor validation, and size formulas.
    use super::*;

    #[test]
    fn nchw_stride_order() {
        let s = nchw_strides(3, 4, 5);
        assert_eq!(s, vec![60, 20, 5, 1]);
    }

    #[test]
    fn nhwc_stride_order() {
        // dims stay in logical [n, c, h, w] order; only the strides differ.
        let s = nhwc_strides(3, 4, 5);
        assert_eq!(s, vec![60, 1, 15, 3]);
    }

    #[test]
    fn conv_output_size_basic() {
        // 3x3 kernel with pad 1 and stride 1 preserves the extent.
        let out = ConvolutionDescriptor::output_size(32, 3, 1, 1, 1);
        assert_eq!(out.ok(), Some(32));
    }

    #[test]
    fn conv_output_size_strided() {
        let out = ConvolutionDescriptor::output_size(32, 3, 1, 2, 1);
        assert_eq!(out.ok(), Some(16));
    }

    #[test]
    fn conv_output_size_dilated() {
        // dilation 2 makes the effective kernel 5; pad 2 compensates.
        let out = ConvolutionDescriptor::output_size(32, 3, 2, 1, 2);
        assert_eq!(out.ok(), Some(32));
    }

    #[test]
    fn conv_output_size_too_small() {
        let out = ConvolutionDescriptor::output_size(3, 5, 0, 1, 1);
        assert!(out.is_err());
    }

    #[test]
    fn conv2d_zero_stride_rejected() {
        let r = ConvolutionDescriptor::conv2d(0, 0, 0, 1, 1, 1, 1);
        assert!(r.is_err());
    }

    #[test]
    fn conv2d_zero_dilation_rejected() {
        let r = ConvolutionDescriptor::conv2d(0, 0, 1, 1, 1, 0, 1);
        assert!(r.is_err());
    }

    #[test]
    fn conv2d_zero_groups_rejected() {
        let r = ConvolutionDescriptor::conv2d(0, 0, 1, 1, 1, 1, 0);
        assert!(r.is_err());
    }

    #[test]
    fn tensor_layout_spatial_dims() {
        assert_eq!(TensorLayout::Nchw.spatial_dims(), 2);
        assert_eq!(TensorLayout::Nhwc.spatial_dims(), 2);
        assert_eq!(TensorLayout::Ncdhw.spatial_dims(), 3);
        assert_eq!(TensorLayout::Ndhwc.spatial_dims(), 3);
        assert_eq!(TensorLayout::RowMajor.spatial_dims(), 0);
    }

    #[test]
    fn tensor_layout_expected_ndim() {
        assert_eq!(TensorLayout::Nchw.expected_ndim(), 4);
        assert_eq!(TensorLayout::Nhwc.expected_ndim(), 4);
        assert_eq!(TensorLayout::Ncdhw.expected_ndim(), 5);
        assert_eq!(TensorLayout::Ndhwc.expected_ndim(), 5);
        assert_eq!(TensorLayout::RowMajor.expected_ndim(), 2);
    }

    #[test]
    fn tensor_layout_channels_last() {
        assert!(TensorLayout::Nhwc.is_channels_last());
        assert!(TensorLayout::Ndhwc.is_channels_last());
        assert!(!TensorLayout::Nchw.is_channels_last());
        assert!(!TensorLayout::Ncdhw.is_channels_last());
        assert!(!TensorLayout::RowMajor.is_channels_last());
    }

    #[test]
    fn from_raw_mismatched_lengths() {
        let r = TensorDesc::<f32>::from_raw(0, vec![1, 2], vec![1], TensorLayout::Nchw);
        assert!(r.is_err());
    }

    #[test]
    fn from_raw_empty_dims() {
        let r = TensorDesc::<f32>::from_raw(0, vec![], vec![], TensorLayout::Nchw);
        assert!(r.is_err());
    }

    #[test]
    fn from_raw_valid() {
        let r = TensorDesc::<f32>::from_raw(0, vec![2, 3], vec![3, 1], TensorLayout::RowMajor);
        let desc = match r {
            Ok(d) => d,
            Err(_) => panic!("valid raw descriptor rejected"),
        };
        assert_eq!(desc.ndim(), 2);
        assert_eq!(desc.numel(), 6);
        assert_eq!(desc.layout, TensorLayout::RowMajor);
    }

    #[test]
    fn activation_variants_are_distinct() {
        assert_ne!(Activation::Relu, Activation::Gelu);
        assert_ne!(Activation::Gelu, Activation::GeluTanh);
        assert_ne!(Activation::Silu, Activation::Sigmoid);
        assert_eq!(Activation::None, Activation::None);
    }

    #[test]
    fn conv_algorithm_debug() {
        // Previously only formatted the value; now pins the derived Debug output.
        assert_eq!(format!("{:?}", ConvAlgorithm::Winograd), "Winograd");
        assert_eq!(format!("{:?}", ConvAlgorithm::ImplicitGemm), "ImplicitGemm");
    }

    #[test]
    fn pool_output_basic() {
        assert_eq!(pool_output_size(4, 2, 2, 0), Some(2));
        assert_eq!(pool_output_size(5, 3, 1, 1), Some(5));
    }

    #[test]
    fn pool_output_padded_and_strided() {
        // padded input 7, kernel 3, stride 2 -> (7 - 3) / 2 + 1 = 3
        assert_eq!(pool_output_size(5, 3, 2, 1), Some(3));
    }

    #[test]
    fn pool_output_zero_stride() {
        assert_eq!(pool_output_size(4, 2, 0, 0), None);
    }

    #[test]
    fn pool_output_kernel_too_large() {
        assert_eq!(pool_output_size(2, 5, 1, 0), None);
    }
}