use std::marker::PhantomData;
use oxicuda_driver::ffi::CUdeviceptr;
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::ir::PtxType;
use crate::error::{BlasError, BlasResult};
/// Math-execution mode requested for kernels launched through this crate.
///
/// NOTE(review): variant semantics are not defined in this file — presumably
/// mirrors cuBLAS-style math modes (default vs. tensor-core vs. fastest
/// available path); confirm against the kernel-selection code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MathMode {
    /// Baseline behavior; no special acceleration requested.
    Default,
    /// Request tensor-core code paths where the element type allows it
    /// (see `GpuFloat::TENSOR_CORE_ELIGIBLE`).
    TensorCore,
    /// Prefer the fastest available path.
    MaxPerformance,
}
/// Where scalar arguments to BLAS routines reside.
///
/// NOTE(review): presumably selects whether scalars like alpha/beta are read
/// from host memory or device memory, as in cuBLAS — confirm at call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PointerMode {
    /// Scalars are host-resident values.
    Host,
    /// Scalars are device-resident pointers.
    Device,
}
/// Memory ordering of a dense matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Layout {
    /// Elements of one row are contiguous (C order); the leading dimension
    /// strides between rows.
    RowMajor,
    /// Elements of one column are contiguous (Fortran/BLAS order); the
    /// leading dimension strides between columns.
    ColMajor,
}
/// Transposition operator applied to a matrix operand (BLAS `op(A)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transpose {
    /// Use the matrix as stored.
    NoTrans,
    /// Use the transpose.
    Trans,
    /// Use the conjugate transpose (same as `Trans` for real types).
    ConjTrans,
}
/// Which part of a (symmetric/triangular) matrix is referenced.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FillMode {
    /// Only the upper triangle is referenced.
    Upper,
    /// Only the lower triangle is referenced.
    Lower,
    /// The full matrix is referenced.
    Full,
}
/// Side from which a matrix factor is applied (BLAS `side` parameter,
/// e.g. `A * B` vs. `B * A` in TRMM/SYMM-style routines).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Side {
    /// Factor applied from the left.
    Left,
    /// Factor applied from the right.
    Right,
}
/// Whether a triangular matrix's diagonal is assumed to be all ones
/// (BLAS `diag` parameter).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagType {
    /// Diagonal entries are read from memory.
    NonUnit,
    /// Diagonal entries are assumed to be 1 and never read.
    Unit,
}
/// Scalar element types that GPU BLAS kernels can operate on.
///
/// Implementors describe how a host value maps to its device
/// representation: the PTX-level type, the element size in bytes, raw-bit
/// round-tripping (used to pass scalars through untyped launch interfaces),
/// and the wider type in which accumulation is carried out.
pub trait GpuFloat: Copy + Send + Sync + 'static + std::fmt::Debug + PartialOrd {
    /// PTX type emitted for this element type.
    const PTX_TYPE: PtxType;
    /// Size of one element in bytes.
    const SIZE: usize;
    /// Short lowercase type name (e.g. for kernel naming / diagnostics).
    const NAME: &'static str;
    /// Whether tensor-core code paths may be selected for this type.
    const TENSOR_CORE_ELIGIBLE: bool;
    /// Type in which dot products / reductions accumulate (e.g. `f32` for
    /// `f16`, `bf16`, and FP8 inputs).
    type Accumulator: GpuFloat;
    /// Raw bit pattern of the value, zero-extended to 64 bits.
    fn to_bits_u64(self) -> u64;
    /// Reconstructs a value from the low bits of `bits`
    /// (inverse of [`GpuFloat::to_bits_u64`]).
    fn from_bits_u64(bits: u64) -> Self;
    /// Additive identity (0.0).
    fn gpu_zero() -> Self;
    /// Multiplicative identity (1.0).
    fn gpu_one() -> Self;
    /// `SIZE` as `u32`; the narrowing cast is safe because element sizes
    /// here are tiny (1–8 bytes).
    #[inline]
    fn size_u32() -> u32 {
        Self::SIZE as u32
    }
}
/// Single precision: maps directly to PTX `f32` and accumulates in `f32`.
impl GpuFloat for f32 {
    const PTX_TYPE: PtxType = PtxType::F32;
    const SIZE: usize = std::mem::size_of::<f32>();
    const NAME: &'static str = "f32";
    const TENSOR_CORE_ELIGIBLE: bool = true;
    type Accumulator = f32;

    /// Zero-extend the IEEE-754 bit pattern into 64 bits.
    #[inline]
    fn to_bits_u64(self) -> u64 {
        let raw: u32 = self.to_bits();
        u64::from(raw)
    }

    /// Rebuild from the low 32 bits (inverse of `to_bits_u64`).
    #[inline]
    fn from_bits_u64(bits: u64) -> Self {
        Self::from_bits(bits as u32)
    }

    #[inline]
    fn gpu_zero() -> Self {
        0.0_f32
    }

    #[inline]
    fn gpu_one() -> Self {
        1.0_f32
    }
}
/// Double precision: maps to PTX `f64` and accumulates in `f64`.
impl GpuFloat for f64 {
    const PTX_TYPE: PtxType = PtxType::F64;
    const SIZE: usize = std::mem::size_of::<f64>();
    const NAME: &'static str = "f64";
    const TENSOR_CORE_ELIGIBLE: bool = true;
    type Accumulator = f64;

    /// The 64-bit pattern already fills the full width; no extension needed.
    #[inline]
    fn to_bits_u64(self) -> u64 {
        f64::to_bits(self)
    }

    /// Rebuild from all 64 bits (inverse of `to_bits_u64`).
    #[inline]
    fn from_bits_u64(bits: u64) -> Self {
        Self::from_bits(bits)
    }

    #[inline]
    fn gpu_zero() -> Self {
        0.0_f64
    }

    #[inline]
    fn gpu_one() -> Self {
        1.0_f64
    }
}
/// IEEE half precision (`half::f16`), available behind the `f16` feature.
/// Accumulation widens to `f32`.
#[cfg(feature = "f16")]
impl GpuFloat for half::f16 {
    const PTX_TYPE: PtxType = PtxType::F16;
    const SIZE: usize = 2;
    const NAME: &'static str = "f16";
    const TENSOR_CORE_ELIGIBLE: bool = true;
    type Accumulator = f32;

    /// Zero-extend the 16-bit pattern into 64 bits.
    #[inline]
    fn to_bits_u64(self) -> u64 {
        let raw: u16 = self.to_bits();
        u64::from(raw)
    }

    /// Rebuild from the low 16 bits (inverse of `to_bits_u64`).
    #[inline]
    fn from_bits_u64(bits: u64) -> Self {
        Self::from_bits(bits as u16)
    }

    #[inline]
    fn gpu_zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn gpu_one() -> Self {
        Self::ONE
    }
}
/// bfloat16 (`half::bf16`), available behind the `f16` feature.
/// Accumulation widens to `f32`.
#[cfg(feature = "f16")]
impl GpuFloat for half::bf16 {
    const PTX_TYPE: PtxType = PtxType::BF16;
    const SIZE: usize = 2;
    const NAME: &'static str = "bf16";
    const TENSOR_CORE_ELIGIBLE: bool = true;
    type Accumulator = f32;

    /// Zero-extend the 16-bit pattern into 64 bits.
    #[inline]
    fn to_bits_u64(self) -> u64 {
        let raw: u16 = self.to_bits();
        u64::from(raw)
    }

    /// Rebuild from the low 16 bits (inverse of `to_bits_u64`).
    #[inline]
    fn from_bits_u64(bits: u64) -> Self {
        Self::from_bits(bits as u16)
    }

    #[inline]
    fn gpu_zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn gpu_one() -> Self {
        Self::ONE
    }
}
/// An 8-bit floating-point value in E4M3 format (1 sign, 4 exponent,
/// 3 mantissa bits; exponent bias 7), stored as its raw bit pattern.
///
/// NOTE(review): the derived `PartialEq`/`PartialOrd` compare raw bits, not
/// numeric values — e.g. any negative value orders above any positive one,
/// and NaN encodings compare equal to themselves. Confirm callers only rely
/// on this for bit-exact/deterministic ordering, not numeric ordering.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
#[repr(transparent)]
pub struct E4M3(pub u8);

/// An 8-bit floating-point value in E5M2 format (1 sign, 5 exponent,
/// 2 mantissa bits; exponent bias 15), stored as its raw bit pattern.
///
/// Same bit-pattern comparison caveat as [`E4M3`].
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
#[repr(transparent)]
pub struct E5M2(pub u8);

// The previous `unsafe impl Send/Sync` blocks were removed: both types are
// plain `u8` wrappers, so `Send` and `Sync` are already auto-implemented by
// the compiler — the manual unsafe impls were redundant.
impl GpuFloat for E4M3 {
const PTX_TYPE: PtxType = PtxType::E4M3;
const SIZE: usize = 1;
const NAME: &'static str = "e4m3";
const TENSOR_CORE_ELIGIBLE: bool = true;
type Accumulator = f32;
#[inline]
fn to_bits_u64(self) -> u64 {
u64::from(self.0)
}
#[inline]
fn from_bits_u64(bits: u64) -> Self {
Self(bits as u8)
}
#[inline]
fn gpu_zero() -> Self {
Self(0x00)
}
#[inline]
fn gpu_one() -> Self {
Self(0x38)
}
}
impl GpuFloat for E5M2 {
const PTX_TYPE: PtxType = PtxType::E5M2;
const SIZE: usize = 1;
const NAME: &'static str = "e5m2";
const TENSOR_CORE_ELIGIBLE: bool = true;
type Accumulator = f32;
#[inline]
fn to_bits_u64(self) -> u64 {
u64::from(self.0)
}
#[inline]
fn from_bits_u64(bits: u64) -> Self {
Self(bits as u8)
}
#[inline]
fn gpu_zero() -> Self {
Self(0x00)
}
#[inline]
fn gpu_one() -> Self {
Self(0x3C)
}
}
/// Describes a strided vector: `n` logical elements spaced `inc` storage
/// slots apart (BLAS-style increment; `inc` is in elements, not bytes).
#[derive(Debug, Clone, Copy)]
pub struct VectorDesc {
    /// Number of logical elements.
    pub n: u32,
    /// Stride between consecutive logical elements, in elements.
    pub inc: u32,
}

impl VectorDesc {
    /// Creates a descriptor for `n` elements with stride `inc`.
    #[must_use]
    pub fn new(n: u32, inc: u32) -> Self {
        Self { n, inc }
    }

    /// Number of storage elements the underlying buffer must provide:
    /// `1 + (n - 1) * inc` (the classic BLAS formula), or 0 when empty.
    #[must_use]
    pub fn required_elements(&self) -> usize {
        match self.n {
            0 => 0,
            n => (n as usize - 1) * self.inc as usize + 1,
        }
    }
}
/// Read-only descriptor for a dense device-resident matrix.
///
/// `ld` is the leading dimension in elements: the stride between
/// consecutive rows (`RowMajor`) or columns (`ColMajor`).
#[derive(Debug, Clone, Copy)]
pub struct MatrixDesc<T: GpuFloat> {
    /// Raw device address of the first element.
    pub ptr: CUdeviceptr,
    pub rows: u32,
    pub cols: u32,
    /// Leading dimension, in elements.
    pub ld: u32,
    pub layout: Layout,
    _phantom: PhantomData<T>,
}

impl<T: GpuFloat> MatrixDesc<T> {
    /// Builds a descriptor over `buf`, validating that it holds at least
    /// `rows * cols` elements. The leading dimension defaults to the
    /// tightly-packed value for `layout`.
    ///
    /// # Errors
    /// Returns [`BlasError::BufferTooSmall`] when `buf` is too short.
    pub fn from_buffer(
        buf: &DeviceBuffer<T>,
        rows: u32,
        cols: u32,
        layout: Layout,
    ) -> BlasResult<Self> {
        let needed = (rows as usize) * (cols as usize);
        if buf.len() < needed {
            return Err(BlasError::BufferTooSmall {
                expected: needed,
                actual: buf.len(),
            });
        }
        // Tightly packed: ld equals the length of the contiguous dimension.
        let tight_ld = if layout == Layout::RowMajor { cols } else { rows };
        Ok(Self::from_raw(buf.as_device_ptr(), rows, cols, tight_ld, layout))
    }

    /// Wraps a raw device pointer and geometry without any validation.
    pub fn from_raw(ptr: CUdeviceptr, rows: u32, cols: u32, ld: u32, layout: Layout) -> Self {
        Self {
            ptr,
            rows,
            cols,
            ld,
            layout,
            _phantom: PhantomData,
        }
    }

    /// Returns a copy with the leading dimension overridden.
    /// NOTE(review): `ld` is not validated against the tight minimum here.
    #[must_use]
    pub fn with_ld(mut self, ld: u32) -> Self {
        self.ld = ld;
        self
    }

    /// Logical element count (`rows * cols`), ignoring any `ld` padding.
    #[must_use]
    pub fn numel(&self) -> usize {
        (self.rows as usize) * (self.cols as usize)
    }

    /// Bytes of device storage spanned, including `ld` padding.
    #[must_use]
    pub fn storage_bytes(&self) -> usize {
        // The "outer" dimension counts how many ld-strided lines exist.
        let outer = if self.layout == Layout::RowMajor {
            self.rows
        } else {
            self.cols
        };
        (outer as usize) * (self.ld as usize) * T::SIZE
    }

    /// Dimensions as seen by an operation after applying `trans`:
    /// `(rows, cols)` untransposed, `(cols, rows)` otherwise.
    #[must_use]
    pub fn effective_dims(&self, trans: Transpose) -> (u32, u32) {
        if matches!(trans, Transpose::NoTrans) {
            (self.rows, self.cols)
        } else {
            (self.cols, self.rows)
        }
    }
}
#[derive(Debug, Clone, Copy)]
pub struct MatrixDescMut<T: GpuFloat> {
pub ptr: CUdeviceptr,
pub rows: u32,
pub cols: u32,
pub ld: u32,
pub layout: Layout,
_phantom: PhantomData<T>,
}
impl<T: GpuFloat> MatrixDescMut<T> {
pub fn from_buffer(
buf: &mut DeviceBuffer<T>,
rows: u32,
cols: u32,
layout: Layout,
) -> BlasResult<Self> {
let required = rows as usize * cols as usize;
if buf.len() < required {
return Err(BlasError::BufferTooSmall {
expected: required,
actual: buf.len(),
});
}
let ld = match layout {
Layout::RowMajor => cols,
Layout::ColMajor => rows,
};
Ok(Self {
ptr: buf.as_device_ptr(),
rows,
cols,
ld,
layout,
_phantom: PhantomData,
})
}
pub fn from_raw(ptr: CUdeviceptr, rows: u32, cols: u32, ld: u32, layout: Layout) -> Self {
Self {
ptr,
rows,
cols,
ld,
layout,
_phantom: PhantomData,
}
}
#[must_use]
pub fn with_ld(mut self, ld: u32) -> Self {
self.ld = ld;
self
}
#[must_use]
pub fn numel(&self) -> usize {
self.rows as usize * self.cols as usize
}
#[must_use]
pub fn storage_bytes(&self) -> usize {
let major = match self.layout {
Layout::RowMajor => self.rows,
Layout::ColMajor => self.cols,
};
major as usize * self.ld as usize * T::SIZE
}
#[must_use]
pub fn effective_dims(&self, trans: Transpose) -> (u32, u32) {
match trans {
Transpose::NoTrans => (self.rows, self.cols),
Transpose::Trans | Transpose::ConjTrans => (self.cols, self.rows),
}
}
#[must_use]
pub fn as_immutable(&self) -> MatrixDesc<T> {
MatrixDesc {
ptr: self.ptr,
rows: self.rows,
cols: self.cols,
ld: self.ld,
layout: self.layout,
_phantom: PhantomData,
}
}
}