use crate::codegen::dialect::gpu;
use burn_tensor::Element;
/// A scalar element type that the JIT backend can place in device buffers and
/// describe to the codegen dialect.
///
/// Implementors are plain-old-data (`bytemuck::Pod`). This lets slices of them be
/// reinterpreted as raw bytes, and back, without copying.
pub trait JitElement
where
    Self: burn_tensor::Element
        + bytemuck::Pod
        + core::fmt::Debug
        + Clone
        + Send
        + Sync
        + Sized
        + 'static,
{
    /// Static name of the element type (e.g. `"f32"`).
    fn type_name() -> &'static str;

    /// Views a slice of elements as its underlying bytes, without copying.
    fn as_bytes(slice: &[Self]) -> &[u8];

    /// Views a byte slice as a slice of elements, without copying.
    ///
    /// NOTE(review): the implementations in this file delegate to
    /// `bytemuck::cast_slice`, which panics if `bytes` is misaligned for `Self` or
    /// its length is not a multiple of `size_of::<Self>()`. Confirm that every
    /// implementor keeps that contract.
    fn from_bytes(bytes: &[u8]) -> &[Self];

    /// The codegen dialect's element kind for this type.
    fn gpu_elem() -> gpu::Elem;

    /// Largest value the backend uses for this type; may be narrower than the
    /// type's own numeric maximum.
    fn maximum_value() -> Self;

    /// Smallest value the backend uses for this type; may be narrower than the
    /// type's own numeric minimum.
    fn minimum_value() -> Self;
}
/// Marker trait for floating-point JIT element types; it adds no items beyond
/// `JitElement` and `Element`.
pub trait FloatElement
where
    Self: JitElement + Element,
{
}
/// Marker trait for integer JIT element types; it adds no items beyond
/// `JitElement` and `Element`.
pub trait IntElement
where
    Self: JitElement + Element,
{
}
/// `u32` maps to the dialect's unsigned-integer kind and exposes its full range.
impl JitElement for u32 {
    fn type_name() -> &'static str {
        "u32"
    }

    fn gpu_elem() -> gpu::Elem {
        gpu::Elem::UInt
    }

    // Full native range: 0 ..= 4_294_967_295.
    fn minimum_value() -> Self {
        Self::MIN
    }

    fn maximum_value() -> Self {
        Self::MAX
    }

    fn as_bytes(slice: &[Self]) -> &[u8] {
        bytemuck::cast_slice::<Self, u8>(slice)
    }

    // Panics (via bytemuck) on misaligned input or a length not divisible by 4.
    fn from_bytes(bytes: &[u8]) -> &[Self] {
        bytemuck::cast_slice::<u8, Self>(bytes)
    }
}
/// `i32` maps to the dialect's signed-integer kind, with a range narrowed by one
/// step at each end.
impl JitElement for i32 {
    fn type_name() -> &'static str {
        "i32"
    }

    fn gpu_elem() -> gpu::Elem {
        gpu::Elem::Int
    }

    // Deliberately one step inside i32's native range (MIN + 1 ..= MAX - 1).
    // NOTE(review): presumably this avoids trouble with the extreme literals on some
    // GPU targets. Confirm before changing, since kernels may rely on these values.
    fn minimum_value() -> Self {
        Self::MIN + 1
    }

    fn maximum_value() -> Self {
        Self::MAX - 1
    }

    fn as_bytes(slice: &[Self]) -> &[u8] {
        bytemuck::cast_slice::<Self, u8>(slice)
    }

    // Panics (via bytemuck) on misaligned input or a length not divisible by 4.
    fn from_bytes(bytes: &[u8]) -> &[Self] {
        bytemuck::cast_slice::<u8, Self>(bytes)
    }
}
/// `f32` maps to the dialect's float kind.
impl JitElement for f32 {
    fn type_name() -> &'static str {
        "f32"
    }

    fn gpu_elem() -> gpu::Elem {
        gpu::Elem::Float
    }

    // Finite extremes: `f32::MIN` is the most negative finite value (-MAX), not -inf.
    fn minimum_value() -> Self {
        Self::MIN
    }

    fn maximum_value() -> Self {
        Self::MAX
    }

    fn as_bytes(slice: &[Self]) -> &[u8] {
        bytemuck::cast_slice::<Self, u8>(slice)
    }

    // Panics (via bytemuck) on misaligned input or a length not divisible by 4.
    fn from_bytes(bytes: &[u8]) -> &[Self] {
        bytemuck::cast_slice::<u8, Self>(bytes)
    }
}
// Marker impls: `i32` is tagged as an integer element, `f32` as a float element.
impl IntElement for i32 {}

impl FloatElement for f32 {}