/// A kernel module described entirely at compile time.
///
/// Implementors expose a short identifying name and the full source text of the module,
/// both as `'static` string constants.
pub trait KernelSource {
    /// Source text of the kernel module.
    const CODE: &'static str;
    /// Short identifier for the kernel module.
    const NAME: &'static str;
}
/// Kernel module named `binary`; its source is embedded at compile time from
/// `kernels/binary.hip`. Presumably hosts the [`BinaryOp`] entry points — TODO confirm.
pub struct BinaryKernel;

impl KernelSource for BinaryKernel {
    const CODE: &'static str = include_str!("kernels/binary.hip");
    const NAME: &'static str = "binary";
}
/// Kernel module named `unary`; its source is embedded at compile time from
/// `kernels/unary.hip`. Presumably hosts the [`UnaryOp`] entry points — TODO confirm.
pub struct UnaryKernel;

impl KernelSource for UnaryKernel {
    const CODE: &'static str = include_str!("kernels/unary.hip");
    const NAME: &'static str = "unary";
}
/// Kernel module named `affine`; its source is embedded at compile time from
/// `kernels/affine.hip`.
pub struct AffineKernel;

impl KernelSource for AffineKernel {
    const CODE: &'static str = include_str!("kernels/affine.hip");
    const NAME: &'static str = "affine";
}
/// Kernel module named `fill`; its source is embedded at compile time from
/// `kernels/fill.hip`.
pub struct FillKernel;

impl KernelSource for FillKernel {
    const CODE: &'static str = include_str!("kernels/fill.hip");
    const NAME: &'static str = "fill";
}
/// Kernel module named `reduce`; its source is embedded at compile time from
/// `kernels/reduce.hip`.
pub struct ReduceKernel;

impl KernelSource for ReduceKernel {
    const CODE: &'static str = include_str!("kernels/reduce.hip");
    const NAME: &'static str = "reduce";
}
/// Kernel module named `conv`; its source is embedded at compile time from
/// `kernels/conv.hip`.
pub struct ConvKernel;

impl KernelSource for ConvKernel {
    const CODE: &'static str = include_str!("kernels/conv.hip");
    const NAME: &'static str = "conv";
}
/// Kernel module named `indexing`; its source is embedded at compile time from
/// `kernels/indexing.hip`.
pub struct IndexingKernel;

impl KernelSource for IndexingKernel {
    const CODE: &'static str = include_str!("kernels/indexing.hip");
    const NAME: &'static str = "indexing";
}
/// Kernel module named `cast`; its source is embedded at compile time from
/// `kernels/cast.hip`.
pub struct CastKernel;

impl KernelSource for CastKernel {
    const CODE: &'static str = include_str!("kernels/cast.hip");
    const NAME: &'static str = "cast";
}
/// Kernel module named `ternary`; its source is embedded at compile time from
/// `kernels/ternary.hip`.
pub struct TernaryKernel;

impl KernelSource for TernaryKernel {
    const CODE: &'static str = include_str!("kernels/ternary.hip");
    const NAME: &'static str = "ternary";
}
/// Kernel module named `sort`; its source is embedded at compile time from
/// `kernels/sort.hip`.
pub struct SortKernel;

impl KernelSource for SortKernel {
    const CODE: &'static str = include_str!("kernels/sort.hip");
    const NAME: &'static str = "sort";
}
/// Selector for a two-operand kernel operation.
///
/// Each variant maps to a kernel entry-point name via [`BinaryOp::kernel_name`].
/// `Hash` is derived (matching `DType`) so operations can be used as map/set keys,
/// e.g. when caching loaded kernels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
    Minimum,
    Maximum,
}

impl BinaryOp {
    /// Returns the kernel function name for this operation: a `b` prefix followed by the
    /// lowercase operation name (e.g. `BinaryOp::Add` → `"badd"`).
    pub fn kernel_name(&self) -> &'static str {
        match self {
            BinaryOp::Add => "badd",
            BinaryOp::Sub => "bsub",
            BinaryOp::Mul => "bmul",
            BinaryOp::Div => "bdiv",
            BinaryOp::Minimum => "bminimum",
            BinaryOp::Maximum => "bmaximum",
        }
    }
}
/// Selector for a single-operand kernel operation.
///
/// Each variant maps to a kernel entry-point name via [`UnaryOp::kernel_name`].
/// `Hash` is derived (matching `DType`) so operations can be used as map/set keys,
/// e.g. when caching loaded kernels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    Copy,
    Relu,
    Sigmoid,
    Tan,
    Exp,
    Log,
    Sin,
    Cos,
    Sqrt,
    Abs,
    Neg,
    Recip,
    Floor,
    Ceil,
    Round,
    Gelu,
    Silu,
    Erf,
}

impl UnaryOp {
    /// Returns the kernel function name for this operation: a `u` prefix followed by the
    /// lowercase operation name (e.g. `UnaryOp::Relu` → `"urelu"`).
    pub fn kernel_name(&self) -> &'static str {
        match self {
            UnaryOp::Copy => "ucopy",
            UnaryOp::Relu => "urelu",
            UnaryOp::Sigmoid => "usigmoid",
            UnaryOp::Tan => "utan",
            UnaryOp::Exp => "uexp",
            UnaryOp::Log => "ulog",
            UnaryOp::Sin => "usin",
            UnaryOp::Cos => "ucos",
            UnaryOp::Sqrt => "usqrt",
            UnaryOp::Abs => "uabs",
            UnaryOp::Neg => "uneg",
            UnaryOp::Recip => "urecip",
            UnaryOp::Floor => "ufloor",
            UnaryOp::Ceil => "uceil",
            UnaryOp::Round => "uround",
            UnaryOp::Gelu => "ugelu",
            UnaryOp::Silu => "usilu",
            UnaryOp::Erf => "uerf",
        }
    }
}
/// Element data types known to the kernels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    BF16,
    F16,
    F32,
    F64,
    I64,
    U32,
    U8,
}

impl DType {
    /// Width of one element of this type, in bytes.
    pub fn size_in_bytes(&self) -> usize {
        // Grouped by storage width rather than by numeric family.
        match *self {
            DType::U8 => 1,
            DType::BF16 | DType::F16 => 2,
            DType::U32 | DType::F32 => 4,
            DType::I64 | DType::F64 => 8,
        }
    }
}
/// Returns the type suffix (e.g. `"f32"`) used to pick the type-specialized kernel for `T`.
///
/// The element type is identified by exact [`std::any::TypeId`] equality. The previous
/// approach substring-matched `std::any::type_name`, which misclassified composite types
/// (`[f64; 2]` → `"f64"`, `(u8, f32)` → `"f32"`). It also relied on a string that std
/// documents as neither unique nor stable across compiler versions.
///
/// # Panics
///
/// Panics if `T` is not exactly one of `f32`, `f64`, `u8`, `u32` or `i64`.
pub fn dtype_suffix<T: Copy + Send + Sync + 'static>() -> &'static str {
    use std::any::TypeId;

    let id = TypeId::of::<T>();
    if id == TypeId::of::<f32>() {
        "f32"
    } else if id == TypeId::of::<f64>() {
        "f64"
    } else if id == TypeId::of::<u8>() {
        "u8"
    } else if id == TypeId::of::<u32>() {
        "u32"
    } else if id == TypeId::of::<i64>() {
        "i64"
    } else {
        // type_name is used only for the human-readable message, never for dispatch.
        panic!("Unsupported dtype for kernel: {}", std::any::type_name::<T>())
    }
}