cubecl-core 0.7.0

CubeCL core crate
Documentation
use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
use cubecl_ir::{Bitwise, Comparison, Operator, Type};
use half::{bf16, f16};

use crate::{
    flex32,
    ir::{Arithmetic, ExpandElement, Scope},
    prelude::{CubePrimitive, ExpandElementTyped},
    tf32, unexpanded,
};

use super::base::{unary_expand, unary_expand_fixed_output};

/// Expansion entry point for the `Operator::Not` unary operation on `bool` elements.
pub mod not {
    use super::*;

    /// Forwards `x` to `unary_expand` together with `Operator::Not` and hands
    /// back the result as a `bool`-typed expand element.
    pub fn expand(scope: &mut Scope, x: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
        let input = x.into();
        let output = unary_expand(scope, input, Operator::Not);
        output.into()
    }
}

/// Expansion entry point for the `Arithmetic::Neg` unary operation.
pub mod neg {
    use super::*;

    /// Forwards `x` to `unary_expand` together with `Arithmetic::Neg`; the
    /// result keeps the element type `E` of the input.
    pub fn expand<E>(scope: &mut Scope, x: ExpandElementTyped<E>) -> ExpandElementTyped<E>
    where
        E: CubePrimitive,
    {
        let input = x.into();
        let output = unary_expand(scope, input, Arithmetic::Neg);
        output.into()
    }
}

/// Generates a public unary-operation trait and implements it for a list of types.
///
/// Arguments, in order:
/// 1. the trait name,
/// 2. the name of the host-side function (`fn(x: Self) -> Self`),
/// 3. the name of the expand function,
/// 4. the operator expression passed to `unary_expand`,
/// 5. a comma-separated list of implementing types (no trailing comma).
///
/// Both trait functions have default bodies, so every generated `impl` is empty.
macro_rules! impl_unary_func {
    ($name:ident, $host_fn:ident, $expand_fn:ident, $op:expr, $($elem:ty),*) => {
        pub trait $name: CubePrimitive + Sized {
            // The default body is only `unexpanded!()`, so `x` is never read.
            #[allow(unused_variables)]
            fn $host_fn(x: Self) -> Self {
                unexpanded!()
            }

            // Converts the input, runs it through `unary_expand` with the
            // operator, and converts the result back to a typed element.
            fn $expand_fn(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
                let input = x.into();
                let output = unary_expand(scope, input, $op);
                output.into()
            }
        }

        $(
            impl $name for $elem {}
        )*
    }
}

// Hand-written `Exp` impl for `f32`: unlike the macro-generated impls, it
// overrides the defaulted `exp` with a body that computes a real value.
// `__expand_exp` is not overridden and keeps the trait's default body.
impl Exp for f32 {
    // `x.exp()` uses method-call syntax, so it cannot select the trait's
    // `exp` (declared as `exp(x: Self)`, without a `self` receiver); it
    // resolves to `f32`'s own `exp` method, so this does not recurse.
    fn exp(x: Self) -> Self {
        x.exp()
    }
}

/// Generates a public unary-operation trait whose expanded output uses a
/// caller-chosen line size instead of the one carried by the input.
///
/// Arguments, in order: trait name, host-side function name, expand function
/// name, operator expression, the value passed to `ty.line(..)` when building
/// the output type, then a comma-separated list of implementing types
/// (no trailing comma).
macro_rules! impl_unary_func_fixed_out_vectorization {
    ($name:ident, $host_fn:ident, $expand_fn:ident, $op:expr, $out_line: expr, $($elem:ty),*) => {
        pub trait $name: CubePrimitive + Sized {
            // The default body is only `unexpanded!()`, so `x` is never read.
            #[allow(unused_variables)]
            fn $host_fn(x: Self) -> Self {
                unexpanded!()
            }

            fn $expand_fn(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
                let input: ExpandElement = x.into();
                // Output type: the input's type re-lined with the fixed value.
                let out_item = input.ty.line($out_line);
                let output = unary_expand_fixed_output(scope, input, out_item, $op);
                output.into()
            }
        }

        $(
            impl $name for $elem {}
        )*
    }
}

/// Generates a public unary-operation trait whose result has a fixed element
/// type (`$out`) that does not depend on the implementing type.
///
/// Arguments, in order: trait name, host-side function name, expand function
/// name, the output type, operator expression, then a comma-separated list of
/// implementing types (no trailing comma).
macro_rules! impl_unary_func_fixed_out_ty {
    ($name:ident, $host_fn:ident, $expand_fn:ident, $out: ty, $op:expr, $($elem:ty),*) => {
        pub trait $name: CubePrimitive + Sized {
            // The default body is only `unexpanded!()`, so `x` is never read.
            #[allow(unused_variables)]
            fn $host_fn(x: Self) -> $out {
                unexpanded!()
            }

            fn $expand_fn(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<$out> {
                let input: ExpandElement = x.into();
                // Output type: the fixed element type, carrying the same
                // line size as the input.
                let out_base = Type::new(<$out as CubePrimitive>::as_type(scope));
                let line_size = input.ty.line_size();
                let out_item = out_base.line(line_size);
                let output = unary_expand_fixed_output(scope, input, out_item, $op);
                output.into()
            }
        }

        $(
            impl $name for $elem {}
        )*
    }
}

// Declares the `Abs` trait (`abs` / `__expand_abs`); the expand function
// passes `Arithmetic::Abs` to `unary_expand`. Implemented for every type
// listed below, which includes the unsigned integer types.
impl_unary_func!(
    Abs,
    abs,
    __expand_abs,
    Arithmetic::Abs,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64
);
// Declares the `Exp` trait (`exp` / `__expand_exp`); the expand function
// passes `Arithmetic::Exp` to `unary_expand`.
impl_unary_func!(
    Exp,
    exp,
    __expand_exp,
    Arithmetic::Exp,
    f16,
    bf16,
    flex32,
    tf32,
    // `f32` is intentionally left out of this list: it has a hand-written
    // `impl Exp for f32` in this file, and a second impl would conflict.
    f64
);
// Declares the `Log` trait (`log` / `__expand_log`); the expand function
// passes `Arithmetic::Log` to `unary_expand`. Implemented for the types below.
impl_unary_func!(
    Log,
    log,
    __expand_log,
    Arithmetic::Log,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Log1p` trait (`log1p` / `__expand_log1p`); the expand function
// passes `Arithmetic::Log1p` to `unary_expand`. Implemented for the types below.
impl_unary_func!(
    Log1p,
    log1p,
    __expand_log1p,
    Arithmetic::Log1p,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Cos` trait (`cos` / `__expand_cos`); the expand function
// passes `Arithmetic::Cos` to `unary_expand`. Implemented for the types below.
impl_unary_func!(
    Cos,
    cos,
    __expand_cos,
    Arithmetic::Cos,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Sin` trait (`sin` / `__expand_sin`); the expand function
// passes `Arithmetic::Sin` to `unary_expand`. Implemented for the types below.
impl_unary_func!(
    Sin,
    sin,
    __expand_sin,
    Arithmetic::Sin,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Tanh` trait (`tanh` / `__expand_tanh`); the expand function
// passes `Arithmetic::Tanh` to `unary_expand`. Implemented for the types below.
impl_unary_func!(
    Tanh,
    tanh,
    __expand_tanh,
    Arithmetic::Tanh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Sqrt` trait (`sqrt` / `__expand_sqrt`); the expand function
// passes `Arithmetic::Sqrt` to `unary_expand`. Implemented for the types below.
impl_unary_func!(
    Sqrt,
    sqrt,
    __expand_sqrt,
    Arithmetic::Sqrt,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Round` trait (`round` / `__expand_round`); the expand function
// passes `Arithmetic::Round` to `unary_expand`. Implemented for the types below.
impl_unary_func!(
    Round,
    round,
    __expand_round,
    Arithmetic::Round,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Floor` trait (`floor` / `__expand_floor`); the expand function
// passes `Arithmetic::Floor` to `unary_expand`. Implemented for the types below.
impl_unary_func!(
    Floor,
    floor,
    __expand_floor,
    Arithmetic::Floor,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Ceil` trait (`ceil` / `__expand_ceil`); the expand function
// passes `Arithmetic::Ceil` to `unary_expand`. Implemented for the types below.
impl_unary_func!(
    Ceil,
    ceil,
    __expand_ceil,
    Arithmetic::Ceil,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Trunc` trait (`trunc` / `__expand_trunc`); the expand function
// passes `Arithmetic::Trunc` to `unary_expand`. Implemented for the types below.
impl_unary_func!(
    Trunc,
    trunc,
    __expand_trunc,
    Arithmetic::Trunc,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Erf` trait (`erf` / `__expand_erf`); the expand function
// passes `Arithmetic::Erf` to `unary_expand`. Implemented for the types below.
impl_unary_func!(
    Erf,
    erf,
    __expand_erf,
    Arithmetic::Erf,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Recip` trait (`recip` / `__expand_recip`); the expand function
// passes `Arithmetic::Recip` to `unary_expand`. Implemented for the types below.
impl_unary_func!(
    Recip,
    recip,
    __expand_recip,
    Arithmetic::Recip,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Magnitude` trait (`magnitude` / `__expand_magnitude`); the
// expand function passes `Arithmetic::Magnitude` to `unary_expand_fixed_output`.
// The literal `0` is the value handed to `ty.line(..)` when building the
// output type, so the output line size is fixed rather than copied from the
// input.
// NOTE(review): presumably `0` requests a non-vectorized (scalar) output —
// confirm against `Type::line`.
impl_unary_func_fixed_out_vectorization!(
    Magnitude,
    magnitude,
    __expand_magnitude,
    Arithmetic::Magnitude,
    0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `Normalize` trait (`normalize` / `__expand_normalize`); the
// expand function passes `Arithmetic::Normalize` to `unary_expand`.
// Implemented for the types below.
impl_unary_func!(
    Normalize,
    normalize,
    __expand_normalize,
    Arithmetic::Normalize,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `CountOnes` trait (`count_ones` / `__expand_count_ones`); the
// expand function passes `Bitwise::CountOnes` to `unary_expand_fixed_output`.
// The first `u32` is the fixed output type; the types after the operator are
// the implementing input types, so the result is `u32` for every input width.
impl_unary_func_fixed_out_ty!(
    CountOnes,
    count_ones,
    __expand_count_ones,
    u32,
    Bitwise::CountOnes,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64
);
// Declares the `ReverseBits` trait (`reverse_bits` / `__expand_reverse_bits`);
// the expand function passes `Bitwise::ReverseBits` to `unary_expand`.
// Implemented for the integer types below.
impl_unary_func!(
    ReverseBits,
    reverse_bits,
    __expand_reverse_bits,
    Bitwise::ReverseBits,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64
);

// Declares the `BitwiseNot` trait (`bitwise_not` / `__expand_bitwise_not`);
// the expand function passes `Bitwise::BitwiseNot` to `unary_expand`.
// Implemented for the integer types below.
impl_unary_func!(
    BitwiseNot,
    bitwise_not,
    __expand_bitwise_not,
    Bitwise::BitwiseNot,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64
);
// Declares the `LeadingZeros` trait (`leading_zeros` / `__expand_leading_zeros`);
// the expand function passes `Bitwise::LeadingZeros` to
// `unary_expand_fixed_output`. The first `u32` is the fixed output type; the
// types after the operator are the implementing input types.
impl_unary_func_fixed_out_ty!(
    LeadingZeros,
    leading_zeros,
    __expand_leading_zeros,
    u32,
    Bitwise::LeadingZeros,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64
);
// Declares the `FindFirstSet` trait (`find_first_set` / `__expand_find_first_set`);
// the expand function passes `Bitwise::FindFirstSet` to
// `unary_expand_fixed_output`. The first `u32` is the fixed output type; the
// types after the operator are the implementing input types.
impl_unary_func_fixed_out_ty!(
    FindFirstSet,
    find_first_set,
    __expand_find_first_set,
    u32,
    Bitwise::FindFirstSet,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64
);
// Declares the `IsNan` trait (`is_nan` / `__expand_is_nan`); the expand
// function passes `Comparison::IsNan` to `unary_expand_fixed_output`.
// `bool` is the fixed output type; the types after the operator are the
// implementing input types.
impl_unary_func_fixed_out_ty!(
    IsNan,
    is_nan,
    __expand_is_nan,
    bool,
    Comparison::IsNan,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Declares the `IsInf` trait (`is_inf` / `__expand_is_inf`); the expand
// function passes `Comparison::IsInf` to `unary_expand_fixed_output`.
// `bool` is the fixed output type; the types after the operator are the
// implementing input types.
impl_unary_func_fixed_out_ty!(
    IsInf,
    is_inf,
    __expand_is_inf,
    bool,
    Comparison::IsInf,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);