use cubecl::{
prelude::{Float, Int, Numeric},
CubeElement,
};
/// Marker trait bundling every bound an element type must satisfy to be used
/// by the JIT backend.
///
/// It declares no items of its own. It requires burn's
/// [`Element`](burn_tensor::Element) together with cubecl's [`CubeElement`]
/// and [`Numeric`], plus `PartialEq`.
pub trait JitElement
where
    Self: burn_tensor::Element + CubeElement + PartialEq + Numeric,
{
}
/// Marker trait for floating-point element types: any [`JitElement`] that
/// also implements cubecl's [`Float`] trait.
///
/// It declares no items of its own.
pub trait FloatElement
where
    Self: JitElement + Float,
{
}
/// Marker trait for integer element types: any [`JitElement`] that also
/// implements cubecl's [`Int`] trait.
///
/// It declares no items of its own.
pub trait IntElement
where
    Self: JitElement + Int,
{
}
/// Emits an empty `impl $trait for $ty {}` for every listed type.
///
/// The element traits have no items, so each implementation is just an opt-in.
/// Listing the types per trait keeps the supported set visible in one place.
macro_rules! impl_marker {
    ($trait:ident => $($ty:ty),+ $(,)?) => {
        $(impl $trait for $ty {})+
    };
}

// Every type usable by the backend, integer and floating-point alike.
impl_marker!(JitElement => u32, i32, f32, half::f16, half::bf16);
// Floating-point subset.
impl_marker!(FloatElement => f32, half::bf16, half::f16);
// Integer subset. `u32` is a `JitElement` here but not an `IntElement`.
impl_marker!(IntElement => i32);