cubecl-core 0.2.0

CubeCL core crate
Documentation
use crate::frontend::operation::base::cmp_expand;
use crate::frontend::{CubeContext, ExpandElementTyped, UInt, BF16, F16, F32, F64, I32, I64};
use crate::ir::Operator;
use crate::prelude::CubePrimitive;

/// Implements equality and ordering for CubeCL primitive wrapper types.
///
/// Two invocation forms are accepted:
///
/// * `impl_cmp!({ Ty | Rhs1; Rhs2, Other | Rhs3 })` — for every listed `Rhs`, implements
///   `PartialEq<Rhs>` and `PartialOrd<Rhs>` for `Ty`. The right-hand side is converted with
///   `(*rhs).into()`, so `Ty: From<Rhs>` and `Rhs: Copy` are required. The comparison is then
///   delegated to `Ty`'s same-type impls, which this form also generates via the single-type
///   arm for each listed `Ty`.
/// * `impl_cmp!(Ty)` — implements `PartialEq`, `Eq` and `PartialOrd` for `Ty` itself. The type
///   must expose `val` and `vectorization` fields. Equality requires both fields to match, and
///   ordering is lexicographic: `val` first, with `vectorization` breaking ties.
///
/// NOTE(review): `Eq` is implemented unconditionally. If `val` is a floating-point field
/// (plausible for the `F*`/`BF16` wrappers — confirm), `NaN != NaN` violates `Eq`'s
/// reflexivity requirement.
macro_rules! impl_cmp {
    ({ $($type:ty| $($rhs:ty);*),* }) => {
        $(
            $(
                impl core::cmp::PartialEq<$rhs> for $type {
                    fn eq(&self, rhs: &$rhs) -> bool {
                        // Convert the plain value into the wrapper so the same-type impl decides.
                        let rhs: Self = (*rhs).into();
                        self == &rhs
                    }
                }

                impl core::cmp::PartialOrd<$rhs> for $type {
                    fn partial_cmp(&self, rhs: &$rhs) -> Option<core::cmp::Ordering> {
                        let rhs: Self = (*rhs).into();
                        core::cmp::PartialOrd::partial_cmp(self, &rhs)
                    }
                }
            )*

            impl_cmp!($type);
        )*
    };
    ($type:ty) => {
        impl core::cmp::PartialEq for $type {
            fn eq(&self, other: &Self) -> bool {
                self.val == other.val && self.vectorization == other.vectorization
            }
        }

        impl core::cmp::Eq for $type {}

        impl core::cmp::PartialOrd for $type {
            fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
                // `val` decides unless it compares equal; an unordered `val` (None) is
                // propagated as-is rather than falling through to `vectorization`.
                match self.val.partial_cmp(&other.val) {
                    Some(core::cmp::Ordering::Equal) => {}
                    ord => return ord,
                }
                self.vectorization.partial_cmp(&other.vectorization)
            }
        }
    };
}

impl_cmp!(
    {
        F16 | f32;u32,
        F32 | f32;u32,
        BF16 | f32;u32,
        F64 | f32;u32,
        I32 | i32;u32,
        I64 | i32;u32,
        UInt | u32
    }
);

/// Expansion of the `!=` operator for CubeCL primitives.
pub mod ne {
    use super::*;

    /// Expands `lhs != rhs` into a `NotEqual` comparison built by `cmp_expand`.
    ///
    /// Returns the boolean expand element produced by the comparison.
    pub fn expand<C: CubePrimitive>(
        context: &mut CubeContext,
        lhs: ExpandElementTyped<C>,
        rhs: ExpandElementTyped<C>,
    ) -> ExpandElementTyped<bool> {
        let operator = Operator::NotEqual;
        let result = cmp_expand(context, lhs.into(), rhs.into(), operator);
        result.into()
    }
}

/// Expansion of the `>` operator for CubeCL primitives.
pub mod gt {
    use super::*;

    /// Expands `lhs > rhs` into a `Greater` comparison built by `cmp_expand`.
    ///
    /// Returns the boolean expand element produced by the comparison.
    pub fn expand<C: CubePrimitive>(
        context: &mut CubeContext,
        lhs: ExpandElementTyped<C>,
        rhs: ExpandElementTyped<C>,
    ) -> ExpandElementTyped<bool> {
        // Lower both typed operands to the untyped form `cmp_expand` accepts.
        let lhs = lhs.into();
        let rhs = rhs.into();
        cmp_expand(context, lhs, rhs, Operator::Greater).into()
    }
}

/// Expansion of the `<` operator for CubeCL primitives.
pub mod lt {
    use super::*;

    /// Expands `lhs < rhs` into a `Lower` comparison built by `cmp_expand`.
    ///
    /// Returns the boolean expand element produced by the comparison.
    pub fn expand<C: CubePrimitive>(
        context: &mut CubeContext,
        lhs: ExpandElementTyped<C>,
        rhs: ExpandElementTyped<C>,
    ) -> ExpandElementTyped<bool> {
        let comparison = cmp_expand(context, lhs.into(), rhs.into(), Operator::Lower);
        comparison.into()
    }
}

/// Expansion of the `>=` operator for CubeCL primitives.
pub mod ge {
    use super::*;

    /// Expands `lhs >= rhs` into a `GreaterEqual` comparison built by `cmp_expand`.
    ///
    /// Returns the boolean expand element produced by the comparison.
    pub fn expand<C: CubePrimitive>(
        context: &mut CubeContext,
        lhs: ExpandElementTyped<C>,
        rhs: ExpandElementTyped<C>,
    ) -> ExpandElementTyped<bool> {
        let (lhs, rhs) = (lhs.into(), rhs.into());
        let op = Operator::GreaterEqual;
        cmp_expand(context, lhs, rhs, op).into()
    }
}

/// Expansion of the `<=` operator for CubeCL primitives.
pub mod le {
    use super::*;

    /// Expands `lhs <= rhs` into a `LowerEqual` comparison built by `cmp_expand`.
    ///
    /// Returns the boolean expand element produced by the comparison.
    pub fn expand<C: CubePrimitive>(
        context: &mut CubeContext,
        lhs: ExpandElementTyped<C>,
        rhs: ExpandElementTyped<C>,
    ) -> ExpandElementTyped<bool> {
        let operator = Operator::LowerEqual;
        let (left, right) = (lhs.into(), rhs.into());
        let out = cmp_expand(context, left, right, operator);
        out.into()
    }
}

/// Expansion of the `==` operator for CubeCL primitives.
pub mod eq {
    use super::*;

    /// Expands `lhs == rhs` into an `Equal` comparison built by `cmp_expand`.
    ///
    /// Returns the boolean expand element produced by the comparison.
    pub fn expand<C: CubePrimitive>(
        context: &mut CubeContext,
        lhs: ExpandElementTyped<C>,
        rhs: ExpandElementTyped<C>,
    ) -> ExpandElementTyped<bool> {
        let left = lhs.into();
        let right = rhs.into();
        let equality = cmp_expand(context, left, right, Operator::Equal);
        equality.into()
    }
}

pub mod add_assign {
    use super::*;

    pub fn expand<C: CubePrimitive>(
        context: &mut CubeContext,
        lhs: ExpandElementTyped<C>,
        rhs: ExpandElementTyped<C>,
    ) -> ExpandElementTyped<C> {
        cmp_expand(context, lhs.into(), rhs.into(), Operator::Add).into()
    }
}