cubecl-core 0.10.0-pre.3

CubeCL core crate
Documentation
use cubecl_ir::{ConstantValue, Scope, StorageType, Type};
use half::{bf16, f16};

use crate::{
    self as cubecl,
    ir::{ElemType, FloatKind},
    prelude::*,
};

use super::Numeric;

mod fp4;
mod fp6;
mod fp8;
mod relaxed;
mod tensor_float;

/// Floating point numbers. Used as input in float kernels.
///
/// Bundles [`Numeric`] with the float-specific operation traits (exponentials, logarithms,
/// trigonometric / hyperbolic functions, rounding, `IsNan` / `IsInf`, ...) and exposes the
/// usual floating-point constants as associated consts.
///
/// The associated constants are named after Rust's `f32`/`f64` constants; the
/// implementations generated by `impl_float!` in this file copy each one from the
/// type's inherent constant of the same name.
pub trait Float:
    Numeric
    + FloatOps
    + Exp
    + Log
    + Log1p
    + Cos
    + Sin
    + Tan
    + Tanh
    + Sinh
    + Cosh
    + ArcCos
    + ArcSin
    + ArcTan
    + ArcSinh
    + ArcCosh
    + ArcTanh
    + Degrees
    + Radians
    + ArcTan2
    + Powf
    + Powi<i32>
    + Hypot
    + Rhypot
    + Sqrt
    + InverseSqrt
    + Round
    + Floor
    + Ceil
    + Trunc
    + Erf
    + Recip
    + Magnitude
    + Normalize
    + Dot
    + IsNan
    + IsInf
    // Lets a plain host value be converted into its expand-time representation.
    + Into<Self::ExpandType>
    + core::ops::Neg<Output = Self>
    + core::cmp::PartialOrd
    + core::cmp::PartialEq
{
    /// Approximate number of significant decimal digits.
    const DIGITS: u32;
    /// Machine epsilon: difference between `1.0` and the next larger representable value.
    const EPSILON: Self;
    /// Positive infinity.
    const INFINITY: Self;
    /// Number of significant digits in base 2.
    const MANTISSA_DIGITS: u32;
    /// Maximum `x` such that `10^x` is a normal value.
    const MAX_10_EXP: i32;
    /// One greater than the maximum possible power-of-2 exponent.
    const MAX_EXP: i32;
    /// Minimum `x` such that `10^x` is a normal value.
    const MIN_10_EXP: i32;
    /// One greater than the minimum possible normal power-of-2 exponent.
    const MIN_EXP: i32;
    /// Smallest positive normal value.
    const MIN_POSITIVE: Self;
    /// Not a Number.
    const NAN: Self;
    /// Negative infinity.
    const NEG_INFINITY: Self;
    /// Radix of the internal representation.
    const RADIX: u32;

    /// Builds a value of this type from an `f32`.
    ///
    /// The `impl_float!` implementations in this file widen `val` to `f64` and then
    /// convert it to the target type.
    fn new(val: f32) -> Self;

    /// Expand-time counterpart of [`Float::new`].
    ///
    /// Forwards to the module-level `__expand_new` function. NOTE(review): that function
    /// is defined outside this trait; presumably it is what `#[cube]`-generated code
    /// calls for `Float::new`. Confirm against the macro expansion.
    fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
        __expand_new(scope, val)
    }
}

/// Minimum, maximum and clamping exposed as methods on float values.
///
/// The default bodies delegate to the free functions `min`, `max` and `clamp`; every
/// [`Float`] receives this trait through the blanket impl below it.
#[cube]
pub trait FloatOps: CubePrimitive + PartialOrd + Sized {
    /// Minimum of `self` and `other`, computed by `cubecl::prelude::min`.
    fn min(self, other: Self) -> Self {
        cubecl::prelude::min(self, other)
    }

    /// Maximum of `self` and `other`, computed by `cubecl::prelude::max`.
    fn max(self, other: Self) -> Self {
        cubecl::prelude::max(self, other)
    }

    /// Clamps `self` between `min` and `max` by delegating to the free function `clamp`.
    fn clamp(self, min: Self, max: Self) -> Self {
        clamp(self, min, max)
    }
}

// Blanket impl: every `Float` picks up the default `min` / `max` / `clamp` methods.
impl<T> FloatOps for T
where
    T: Float,
{
}
// Expand-time side of `FloatOps`: each `__expand_*_method` forwards to the `expand`
// entry point of the free function with the same name.
impl<T> FloatOpsExpand for NativeExpand<T>
where
    T: FloatOps + CubePrimitive,
{
    /// Expands `self.clamp(min, max)` through `clamp::expand`.
    fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
        clamp::expand(scope, self, min, max)
    }

    /// Expands `self.max(other)` through `max::expand`.
    fn __expand_max_method(self, scope: &mut Scope, other: Self) -> Self {
        max::expand(scope, self, other)
    }

    /// Expands `self.min(other)` through `min::expand`.
    fn __expand_min_method(self, scope: &mut Scope, other: Self) -> Self {
        min::expand(scope, self, other)
    }
}

/// Implements every trait a host float type needs to act as a CubeCL [`Float`].
///
/// Invocation forms:
/// - `impl_float!(half TY, KIND)`: values are built from `f64` with `TY::from_f64`
///   (used for the `half` crate's `f16` / `bf16`).
/// - `impl_float!(TY, KIND)`: values are built from `f64` with an `as` cast.
/// - `impl_float!(TY, KIND, CONVERT)`: `CONVERT` is a callable expression taking an
///   `f64` and returning `TY`; the two shorter forms forward to this one.
///
/// `KIND` is the `FloatKind` variant reported as the type's native IR element type.
macro_rules! impl_float {
    (half $float:ident, $float_kind:ident) => {
        impl_float!($float, $float_kind, |raw| $float::from_f64(raw));
    };
    ($float:ident, $float_kind:ident) => {
        impl_float!($float, $float_kind, |raw| raw as $float);
    };
    ($float:ident, $float_kind:ident, $from_f64:expr) => {
        impl CubeType for $float {
            type ExpandType = NativeExpand<$float>;
        }

        impl IntoRuntime for $float {
            fn __expand_runtime_method(self, _scope: &mut Scope) -> NativeExpand<Self> {
                self.into()
            }
        }

        impl IntoMut for $float {
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl NativeAssign for $float {}

        impl Scalar for $float {}

        impl CubePrimitive for $float {
            type Scalar = Self;
            type Size = Const<1>;
            type WithScalar<S: Scalar> = S;

            /// Native IR type of this primitive: a scalar float of the kind given to the macro.
            fn as_type_native() -> Option<Type> {
                let elem = ElemType::Float(FloatKind::$float_kind);
                Some(StorageType::Scalar(elem).into())
            }

            /// Rebuilds the host value from an IR constant.
            ///
            /// # Panics
            /// Panics (via `unreachable!`) if `value` is not a `ConstantValue::Float`.
            fn from_const_value(value: ConstantValue) -> Self {
                match value {
                    ConstantValue::Float(inner) => $from_f64(inner),
                    _ => unreachable!(),
                }
            }
        }

        impl Numeric for $float {
            fn min_value() -> Self {
                <$float as num_traits::Float>::min_value()
            }
            fn max_value() -> Self {
                <$float as num_traits::Float>::max_value()
            }
        }

        impl Float for $float {
            const DIGITS: u32 = $float::DIGITS;
            const EPSILON: Self = $float::EPSILON;
            const INFINITY: Self = $float::INFINITY;
            const MANTISSA_DIGITS: u32 = $float::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $float::MAX_10_EXP;
            const MAX_EXP: i32 = $float::MAX_EXP;
            const MIN_10_EXP: i32 = $float::MIN_10_EXP;
            const MIN_EXP: i32 = $float::MIN_EXP;
            const MIN_POSITIVE: Self = $float::MIN_POSITIVE;
            const NAN: Self = $float::NAN;
            const NEG_INFINITY: Self = $float::NEG_INFINITY;
            const RADIX: u32 = $float::RADIX;

            fn new(val: f32) -> Self {
                // Widening `f32 -> f64` is exact, so only the final conversion can round.
                $from_f64(f64::from(val))
            }
        }
    };
}

// Built-in float types: converted from `f64` with a plain `as` cast.
impl_float!(f32, F32);
impl_float!(f64, F64);

// `half` crate types: converted from `f64` with their `from_f64` constructor.
impl_float!(half f16, F16);
impl_float!(half bf16, BF16);