use cubecl_common::{e4m3, e5m2, ue8m0};
use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType};
use crate::{
Runtime,
compute::KernelLauncher,
prelude::{
CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, Numeric,
ScalarArgSettings, into_mut_expand_element, into_runtime_expand_element,
},
};
/// In kernel expansion, `e4m3` values are represented by a typed expand element of itself.
impl CubeType for e4m3 {
    type ExpandType = ExpandElementTyped<Self>;
}
/// `e4m3` maps directly onto the IR float kind `FloatKind::E4M3`.
impl CubePrimitive for e4m3 {
    /// Returns the IR storage type backing `e4m3` (always `Some`).
    fn as_type_native() -> Option<StorageType> {
        Some(ElemType::Float(FloatKind::E4M3).into())
    }

    /// Builds an `e4m3` from a compile-time scalar constant via `e4m3::from_f64`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not a `ConstantScalarValue::Float`. Only float constants are
    /// expected to reach this type.
    fn from_const_value(value: ConstantScalarValue) -> Self {
        let ConstantScalarValue::Float(value, _) = value else {
            unreachable!("e4m3 constants must be ConstantScalarValue::Float")
        };
        e4m3::from_f64(value)
    }
}
/// Turns a host-side `e4m3` value into a runtime expand element registered in a kernel scope.
impl IntoRuntime for e4m3 {
    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
        // First wrap the raw value as a typed expand element, then hand it to the shared
        // helper that materialises it inside `scope`, and finally re-type the result.
        let typed: ExpandElementTyped<Self> = self.into();
        let runtime = into_runtime_expand_element(scope, typed);
        runtime.into()
    }
}
impl Numeric for e4m3 {
fn min_value() -> Self {
Self::from_f64(Self::MIN)
}
fn max_value() -> Self {
Self::from_f64(Self::MAX)
}
}
/// `e4m3` needs no type-specific handling to become mutable; it uses the generic helper.
impl ExpandElementIntoMut for e4m3 {
    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
        // Delegate unchanged to the prelude helper shared by primitive types.
        into_mut_expand_element(scope, elem)
    }
}
/// Passing `e4m3` as a scalar kernel argument is not implemented yet.
impl ScalarArgSettings for e4m3 {
    /// # Panics
    ///
    /// Always panics (via `todo!`). Registering an `e4m3` scalar with the launcher is
    /// not supported yet.
    fn register<R: Runtime>(&self, _settings: &mut KernelLauncher<R>) {
        todo!("Not yet supported for scalars")
    }
}
/// In kernel expansion, `e5m2` values are represented by a typed expand element of itself.
impl CubeType for e5m2 {
    type ExpandType = ExpandElementTyped<Self>;
}
/// `e5m2` maps directly onto the IR float kind `FloatKind::E5M2`.
impl CubePrimitive for e5m2 {
    /// Returns the IR storage type backing `e5m2` (always `Some`).
    fn as_type_native() -> Option<StorageType> {
        Some(ElemType::Float(FloatKind::E5M2).into())
    }

    /// Builds an `e5m2` from a compile-time scalar constant via `e5m2::from_f64`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not a `ConstantScalarValue::Float`. Only float constants are
    /// expected to reach this type.
    fn from_const_value(value: ConstantScalarValue) -> Self {
        let ConstantScalarValue::Float(value, _) = value else {
            unreachable!("e5m2 constants must be ConstantScalarValue::Float")
        };
        e5m2::from_f64(value)
    }
}
/// Turns a host-side `e5m2` value into a runtime expand element registered in a kernel scope.
impl IntoRuntime for e5m2 {
    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
        // First wrap the raw value as a typed expand element, then hand it to the shared
        // helper that materialises it inside `scope`, and finally re-type the result.
        let typed: ExpandElementTyped<Self> = self.into();
        let runtime = into_runtime_expand_element(scope, typed);
        runtime.into()
    }
}
impl Numeric for e5m2 {
fn min_value() -> Self {
Self::from_f64(Self::MIN)
}
fn max_value() -> Self {
Self::from_f64(Self::MAX)
}
}
/// `e5m2` needs no type-specific handling to become mutable; it uses the generic helper.
impl ExpandElementIntoMut for e5m2 {
    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
        // Delegate unchanged to the prelude helper shared by primitive types.
        into_mut_expand_element(scope, elem)
    }
}
/// Passing `e5m2` as a scalar kernel argument is not implemented yet.
impl ScalarArgSettings for e5m2 {
    /// # Panics
    ///
    /// Always panics (via `todo!`). Registering an `e5m2` scalar with the launcher is
    /// not supported yet.
    fn register<R: Runtime>(&self, _settings: &mut KernelLauncher<R>) {
        todo!("Not yet supported for scalars")
    }
}
/// In kernel expansion, `ue8m0` values are represented by a typed expand element of itself.
impl CubeType for ue8m0 {
    type ExpandType = ExpandElementTyped<Self>;
}
/// `ue8m0` maps directly onto the IR float kind `FloatKind::UE8M0`.
impl CubePrimitive for ue8m0 {
    /// Returns the IR storage type backing `ue8m0` (always `Some`).
    fn as_type_native() -> Option<StorageType> {
        Some(ElemType::Float(FloatKind::UE8M0).into())
    }

    /// Builds a `ue8m0` from a compile-time scalar constant via `ue8m0::from_f64`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not a `ConstantScalarValue::Float`. Only float constants are
    /// expected to reach this type.
    fn from_const_value(value: ConstantScalarValue) -> Self {
        let ConstantScalarValue::Float(value, _) = value else {
            unreachable!("ue8m0 constants must be ConstantScalarValue::Float")
        };
        ue8m0::from_f64(value)
    }
}
/// Turns a host-side `ue8m0` value into a runtime expand element registered in a kernel scope.
impl IntoRuntime for ue8m0 {
    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
        // First wrap the raw value as a typed expand element, then hand it to the shared
        // helper that materialises it inside `scope`, and finally re-type the result.
        let typed: ExpandElementTyped<Self> = self.into();
        let runtime = into_runtime_expand_element(scope, typed);
        runtime.into()
    }
}
impl Numeric for ue8m0 {
fn min_value() -> Self {
Self::from_f64(Self::MIN)
}
fn max_value() -> Self {
Self::from_f64(Self::MAX)
}
}
/// `ue8m0` needs no type-specific handling to become mutable; it uses the generic helper.
impl ExpandElementIntoMut for ue8m0 {
    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
        // Delegate unchanged to the prelude helper shared by primitive types.
        into_mut_expand_element(scope, elem)
    }
}
/// Passing `ue8m0` as a scalar kernel argument is not implemented yet.
impl ScalarArgSettings for ue8m0 {
    /// # Panics
    ///
    /// Always panics (via `todo!`). Registering a `ue8m0` scalar with the launcher is
    /// not supported yet.
    fn register<R: Runtime>(&self, _settings: &mut KernelLauncher<R>) {
        todo!("Not yet supported for scalars")
    }
}