use cubecl_ir::{ConstantValue, Scope, StorageType, Type};
use half::{bf16, f16};
use crate::{
self as cubecl,
ir::{ElemType, FloatKind},
prelude::*,
};
use super::Numeric;
mod fp4;
mod fp6;
mod fp8;
mod relaxed;
mod tensor_float;
/// Frontend trait implemented by floating-point element types usable from cube code.
///
/// On top of the [`Numeric`] requirements it bundles the math-operation traits listed in the
/// supertrait bounds (exponentials/logarithms, trigonometric and hyperbolic functions,
/// rounding, `Erf`, `Magnitude`/`Normalize`/`Dot`, `IsNan`/`IsInf`, ...), the [`FloatOps`]
/// ordering helpers, and host-side negation and comparison.
pub trait Float:
    Numeric
    + FloatOps
    + Exp
    + Log
    + Log1p
    + Cos
    + Sin
    + Tan
    + Tanh
    + Sinh
    + Cosh
    + ArcCos
    + ArcSin
    + ArcTan
    + ArcSinh
    + ArcCosh
    + ArcTanh
    + Degrees
    + Radians
    + ArcTan2
    + Powf
    + Powi<i32>
    + Hypot
    + Rhypot
    + Sqrt
    + InverseSqrt
    + Round
    + Floor
    + Ceil
    + Trunc
    + Erf
    + Recip
    + Magnitude
    + Normalize
    + Dot
    + IsNan
    + IsInf
    // A host value can be converted into its expand-time representation
    // (the `IntoRuntime` impls in this file rely on this via `self.into()`).
    + Into<Self::ExpandType>
    + core::ops::Neg<Output = Self>
    + core::cmp::PartialOrd
    + core::cmp::PartialEq
{
    // For the primitive impls generated by `impl_float!` in this file, every constant below
    // mirrors the primitive type's inherent constant of the same name.

    /// Approximate number of significant digits in base 10.
    const DIGITS: u32;
    /// Difference between `1.0` and the next larger representable value.
    const EPSILON: Self;
    /// Positive infinity.
    const INFINITY: Self;
    /// Number of significant digits in base 2.
    const MANTISSA_DIGITS: u32;
    /// Maximum `x` for which `10^x` is a normal value.
    const MAX_10_EXP: i32;
    /// One greater than the maximum possible power-of-2 exponent.
    const MAX_EXP: i32;
    /// Minimum `x` for which `10^x` is a normal value.
    const MIN_10_EXP: i32;
    /// One greater than the minimum possible normal power-of-2 exponent.
    const MIN_EXP: i32;
    /// Smallest positive normal value.
    const MIN_POSITIVE: Self;
    /// Not a Number (NaN).
    const NAN: Self;
    /// Negative infinity.
    const NEG_INFINITY: Self;
    /// Radix (base) of the internal representation.
    const RADIX: u32;

    /// Creates a value of this type from an `f32`.
    ///
    /// The impls generated by `impl_float!` in this file widen `val` to `f64` and then
    /// convert it to the target type.
    fn new(val: f32) -> Self;

    /// Expand-time counterpart of [`Float::new`].
    fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
        // An unqualified call cannot name an associated function (that needs `Self::`), so
        // this resolves to a free function `__expand_new` brought in from elsewhere and does
        // not recurse. NOTE(review): presumably imported through `prelude::*` — confirm.
        __expand_new(scope, val)
    }
}
/// Ordering helpers (`min`, `max`, `clamp`) available on float frontend types.
///
/// Each default method forwards to the frontend free function of the same name. The
/// `FloatOpsExpand` companion trait implemented for `NativeExpand<T>` in this file is
/// presumably generated from this definition by `#[cube]` — confirm against the macro.
#[cube]
pub trait FloatOps: CubePrimitive + PartialOrd + Sized {
    /// Forwards to the frontend `min(self, other)` function.
    fn min(self, other: Self) -> Self {
        cubecl::prelude::min(self, other)
    }

    /// Forwards to the frontend `max(self, other)` function.
    fn max(self, other: Self) -> Self {
        cubecl::prelude::max(self, other)
    }

    /// Forwards to the frontend `clamp(self, min, max)` function.
    fn clamp(self, min: Self, max: Self) -> Self {
        clamp(self, min, max)
    }
}
/// Blanket impl: every [`Float`] type picks up the default `min`/`max`/`clamp` helpers.
impl<F> FloatOps for F
where
    F: Float,
{
}
/// Expand-time version of [`FloatOps`] for natively expanded values.
///
/// Each method lowers to the `expand` entry point of the frontend function with the same
/// name, threading the current `scope` through.
impl<T> FloatOpsExpand for NativeExpand<T>
where
    T: CubePrimitive + FloatOps,
{
    fn __expand_clamp_method(self, scope: &mut Scope, lower: Self, upper: Self) -> Self {
        clamp::expand(scope, self, lower, upper)
    }

    fn __expand_max_method(self, scope: &mut Scope, rhs: Self) -> Self {
        max::expand(scope, self, rhs)
    }

    fn __expand_min_method(self, scope: &mut Scope, rhs: Self) -> Self {
        min::expand(scope, self, rhs)
    }
}
/// Implements the frontend element traits (`CubeType`, `Scalar`, `NativeAssign`,
/// `CubePrimitive`, `IntoRuntime`, `IntoMut`, `Numeric` and [`Float`]) for a host float type.
///
/// Accepted forms:
/// - `impl_float!(half ty, Kind)`: `half`-crate types, built with the inherent `ty::from_f64`.
/// - `impl_float!(ty, Kind)`: native floats, built with an `as` cast from `f64`.
/// - `impl_float!(ty, Kind, ctor)`: explicit `f64 -> ty` constructor expression.
///
/// `Kind` is the `FloatKind` variant reported as the type's IR storage type.
macro_rules! impl_float {
    (half $ty:ident, $variant:ident) => {
        impl_float!($ty, $variant, $ty::from_f64);
    };
    ($ty:ident, $variant:ident) => {
        impl_float!($ty, $variant, |raw| raw as $ty);
    };
    ($ty:ident, $variant:ident, $ctor:expr) => {
        impl CubeType for $ty {
            type ExpandType = NativeExpand<$ty>;
        }

        impl Scalar for $ty {}

        impl NativeAssign for $ty {}

        impl CubePrimitive for $ty {
            type Scalar = Self;
            type Size = Const<1>;
            type WithScalar<S: Scalar> = S;

            fn as_type_native() -> Option<Type> {
                let storage = StorageType::Scalar(ElemType::Float(FloatKind::$variant));
                Some(storage.into())
            }

            fn from_const_value(value: ConstantValue) -> Self {
                // Constants attached to a float type are always stored as floats.
                let ConstantValue::Float(raw) = value else {
                    unreachable!()
                };
                $ctor(raw)
            }
        }

        impl IntoRuntime for $ty {
            fn __expand_runtime_method(self, _scope: &mut Scope) -> NativeExpand<Self> {
                self.into()
            }
        }

        impl IntoMut for $ty {
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl Numeric for $ty {
            fn min_value() -> Self {
                <$ty as num_traits::Float>::min_value()
            }

            fn max_value() -> Self {
                <$ty as num_traits::Float>::max_value()
            }
        }

        impl Float for $ty {
            const DIGITS: u32 = $ty::DIGITS;
            const EPSILON: Self = $ty::EPSILON;
            const INFINITY: Self = $ty::INFINITY;
            const MANTISSA_DIGITS: u32 = $ty::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $ty::MAX_10_EXP;
            const MAX_EXP: i32 = $ty::MAX_EXP;
            const MIN_10_EXP: i32 = $ty::MIN_10_EXP;
            const MIN_EXP: i32 = $ty::MIN_EXP;
            const MIN_POSITIVE: Self = $ty::MIN_POSITIVE;
            const NAN: Self = $ty::NAN;
            const NEG_INFINITY: Self = $ty::NEG_INFINITY;
            const RADIX: u32 = $ty::RADIX;

            fn new(val: f32) -> Self {
                // `f32 -> f64` is exact, so only the final conversion can round.
                $ctor(f64::from(val))
            }
        }
    };
}

impl_float!(f32, F32);
impl_float!(f64, F64);
impl_float!(half f16, F16);
impl_float!(half bf16, BF16);