cubecl-core 0.2.0

CubeCL core crate
Documentation
use crate::frontend::{Array, CubeContext, ExpandElement, SharedMemory, Tensor, UInt};
use crate::frontend::{BF16, F16, F32, F64, I32, I64};
use crate::{ir, unexpanded};

/// Generates placeholder compound-assignment trait impls (e.g. `AddAssign`).
///
/// Input shape: `(Trait|method) => { Type | Rhs1; Rhs2; ..., OtherType | ... }`.
/// For every `Type` entry this emits `Trait<Self>` (the trait's default right-hand
/// side) plus one `Trait<RhsN>` impl per listed right-hand-side type. Every generated
/// method body is `unexpanded!()`; presumably only the macro-expanded form of these
/// operators is meant to execute — TODO confirm against `unexpanded!`'s definition.
macro_rules! impl_op_assign {
    (($op_trait:ident|$op_fn:ident) => { $($target:ty| $($operand:ty);*),* }) => {
        $(
            impl $op_trait for $target {
                fn $op_fn(&mut self, _rhs: Self) {
                    unexpanded!()
                }
            }

            $(
                impl $op_trait<$operand> for $target {
                    fn $op_fn(&mut self, _rhs: $operand) {
                        unexpanded!()
                    }
                }
            )*
        )*
    };
}

pub mod assign {
    use super::*;

    use self::ir::{Operator, UnaryOperator};

    /// Registers an `Operator::Assign` on `context`.
    ///
    /// `input` becomes the operation's source and `output` its destination. Both are
    /// first converted into [`ExpandElement`]s and then dereferenced to their underlying
    /// IR variables. The converted elements stay alive until after registration.
    pub fn expand<I: Into<ExpandElement>, O: Into<ExpandElement>>(
        context: &mut CubeContext,
        input: I,
        output: O,
    ) {
        let source: ExpandElement = input.into();
        let destination: ExpandElement = output.into();

        let assignment = UnaryOperator {
            input: *source,
            out: *destination,
        };
        context.register(Operator::Assign(assignment));
    }
}

pub mod index_assign {
    use crate::{
        frontend::CubeType,
        prelude::{ExpandElementTyped, SliceMut},
        unexpanded,
    };

    use self::ir::{BinaryOperator, Operator, Variable};

    use super::*;

    pub fn expand<A: CubeType + core::ops::Index<UInt>>(
        context: &mut CubeContext,
        array: ExpandElementTyped<A>,
        index: ExpandElementTyped<UInt>,
        value: ExpandElementTyped<A::Output>,
    ) where
        A::Output: CubeType + Sized,
    {
        let index: Variable = index.expand.into();
        let index = match index {
            Variable::ConstantScalar(value) => {
                Variable::ConstantScalar(ir::ConstantScalarValue::UInt(value.as_u64()))
            }
            _ => index,
        };
        context.register(Operator::IndexAssign(BinaryOperator {
            lhs: index,
            rhs: value.expand.into(),
            out: array.expand.into(),
        }));
    }

    macro_rules! impl_index {
        ($type:ident) => {
            impl<E: CubeType, I: Into<UInt>> core::ops::IndexMut<I> for $type<E> {
                fn index_mut(&mut self, _index: I) -> &mut Self::Output {
                    unexpanded!()
                }
            }
        };
    }
    macro_rules! impl_index_vec {
        ($($type:ident),*) => {
            $(
                impl core::ops::IndexMut<UInt> for $type {
                    fn index_mut(&mut self, _index: UInt) -> &mut Self::Output {
                        unexpanded!()
                    }
                }
                impl core::ops::IndexMut<u32> for $type {
                    fn index_mut(&mut self, _index: u32) -> &mut Self::Output {
                        unexpanded!()
                    }
                }

            )*
        };
    }

    impl_index!(Array);
    impl_index!(Tensor);
    impl_index!(SharedMemory);
    impl_index_vec!(I64, I32, F16, BF16, F32, F64, UInt);

    impl<'a, E: CubeType, I: Into<UInt>> core::ops::IndexMut<I> for SliceMut<'a, E> {
        fn index_mut(&mut self, _index: I) -> &mut Self::Output {
            unexpanded!()
        }
    }
}

pub mod index {
    use crate::{
        frontend::{
            operation::base::{binary_expand, binary_expand_no_vec},
            CubeType,
        },
        prelude::{ExpandElementTyped, Slice, SliceMut},
        unexpanded,
    };

    use self::ir::{Operator, Variable};

    use super::*;

    pub fn expand<A: CubeType + core::ops::Index<UInt>>(
        context: &mut CubeContext,
        array: ExpandElementTyped<A>,
        index: ExpandElementTyped<UInt>,
    ) -> ExpandElementTyped<A::Output>
    where
        A::Output: CubeType + Sized,
    {
        let index: ExpandElement = index.into();
        let index_var: Variable = *index;
        let index = match index_var {
            Variable::ConstantScalar(value) => ExpandElement::Plain(Variable::ConstantScalar(
                ir::ConstantScalarValue::UInt(value.as_u64()),
            )),
            _ => index,
        };
        let array: ExpandElement = array.into();
        let var: Variable = *array;
        let var = match var {
            Variable::Local { .. } => binary_expand_no_vec(context, array, index, Operator::Index),
            _ => binary_expand(context, array, index, Operator::Index),
        };

        ExpandElementTyped::new(var)
    }

    macro_rules! impl_index {
        ($type:ident) => {
            impl<E: CubeType, I: Into<UInt>> core::ops::Index<I> for $type<E> {
                type Output = E;

                fn index(&self, _index: I) -> &Self::Output {
                    unexpanded!()
                }
            }
        };
    }

    macro_rules! impl_index_vec {
        ($($type:ident),*) => {
            $(
                impl core::ops::Index<UInt> for $type {
                    type Output = Self;

                    fn index(&self, _index: UInt) -> &Self::Output {
                        unexpanded!()
                    }
                }

                impl core::ops::Index<u32> for $type {
                    type Output = Self;

                    fn index(&self, _index: u32) -> &Self::Output {
                        unexpanded!()
                    }
                }
            )*
        };
    }

    impl_index!(Array);
    impl_index!(Tensor);
    impl_index!(SharedMemory);

    impl_index_vec!(I64, I32, F16, BF16, F32, F64, UInt);

    impl<'a, E: CubeType, I: Into<UInt>> core::ops::Index<I> for SliceMut<'a, E> {
        type Output = E;
        fn index(&self, _index: I) -> &Self::Output {
            unexpanded!()
        }
    }

    impl<'a, E: CubeType, I: Into<UInt>> core::ops::Index<I> for Slice<'a, E> {
        type Output = E;
        fn index(&self, _index: I) -> &Self::Output {
            unexpanded!()
        }
    }
}

pub mod add_assign_array_op {
    use super::*;
    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};

    use self::ir::Operator;

    /// Expansion of an indexed compound addition, `array[index] += value`.
    ///
    /// Forwards every argument to `array_assign_binary_op_expand`, using
    /// `Operator::Add` as the combining operation.
    pub fn expand<A>(
        context: &mut CubeContext,
        array: ExpandElementTyped<A>,
        index: ExpandElementTyped<UInt>,
        value: ExpandElementTyped<A::Output>,
    ) where
        A: CubeType + core::ops::Index<UInt>,
        A::Output: CubeType + Sized,
    {
        let combine = Operator::Add;
        array_assign_binary_op_expand(context, array, index, value, combine);
    }
}

pub mod sub_assign_array_op {
    use super::*;
    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};

    use self::ir::Operator;

    /// Expansion of an indexed compound subtraction, `array[index] -= value`.
    ///
    /// Forwards every argument to `array_assign_binary_op_expand`, using
    /// `Operator::Sub` as the combining operation.
    pub fn expand<A>(
        context: &mut CubeContext,
        array: ExpandElementTyped<A>,
        index: ExpandElementTyped<UInt>,
        value: ExpandElementTyped<A::Output>,
    ) where
        A: CubeType + core::ops::Index<UInt>,
        A::Output: CubeType + Sized,
    {
        let combine = Operator::Sub;
        array_assign_binary_op_expand(context, array, index, value, combine);
    }
}

pub mod mul_assign_array_op {
    use super::*;
    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};

    use self::ir::Operator;

    /// Expansion of an indexed compound multiplication, `array[index] *= value`.
    ///
    /// Forwards every argument to `array_assign_binary_op_expand`, using
    /// `Operator::Mul` as the combining operation.
    pub fn expand<A>(
        context: &mut CubeContext,
        array: ExpandElementTyped<A>,
        index: ExpandElementTyped<UInt>,
        value: ExpandElementTyped<A::Output>,
    ) where
        A: CubeType + core::ops::Index<UInt>,
        A::Output: CubeType + Sized,
    {
        let combine = Operator::Mul;
        array_assign_binary_op_expand(context, array, index, value, combine);
    }
}

pub mod div_assign_array_op {
    use super::*;
    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};

    use self::ir::Operator;

    /// Expansion of an indexed compound division, `array[index] /= value`.
    ///
    /// Forwards every argument to `array_assign_binary_op_expand`, using
    /// `Operator::Div` as the combining operation.
    pub fn expand<A>(
        context: &mut CubeContext,
        array: ExpandElementTyped<A>,
        index: ExpandElementTyped<UInt>,
        value: ExpandElementTyped<A::Output>,
    ) where
        A: CubeType + core::ops::Index<UInt>,
        A::Output: CubeType + Sized,
    {
        let combine = Operator::Div;
        array_assign_binary_op_expand(context, array, index, value, combine);
    }
}

pub mod add_assign_op {
    use super::*;

    use self::ir::Operator;
    use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64};
    use core::ops::AddAssign;

    /// Expansion of the compound `+=` operator.
    ///
    /// Converts both operands into [`ExpandElement`]s (left side first) and hands them
    /// to `assign_op_expand` with `Operator::Add`. The element that helper produces is
    /// returned.
    pub fn expand<L, R>(context: &mut CubeContext, lhs: L, rhs: R) -> ExpandElement
    where
        L: Into<ExpandElement>,
        R: Into<ExpandElement>,
    {
        let target: ExpandElement = lhs.into();
        let operand: ExpandElement = rhs.into();
        assign_op_expand(context, target, operand, Operator::Add)
    }

    // Placeholder `AddAssign` impls. Each type accepts itself plus the listed
    // primitive right-hand sides; the bodies are `unexpanded!()`.
    impl_op_assign!(
        (AddAssign|add_assign) => {
            UInt | u32,
            I32 | i32;u32,
            I64 | i32;u32,
            F16 | f32;u32,
            BF16 | f32;u32,
            F32 | f32;u32,
            F64 | f32;u32
        }
    );
}

pub mod sub_assign_op {
    use super::*;

    use self::ir::Operator;
    use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64};
    use core::ops::SubAssign;

    /// Expansion of the compound `-=` operator.
    ///
    /// Converts both operands into [`ExpandElement`]s (left side first) and hands them
    /// to `assign_op_expand` with `Operator::Sub`. The element that helper produces is
    /// returned.
    pub fn expand<L, R>(context: &mut CubeContext, lhs: L, rhs: R) -> ExpandElement
    where
        L: Into<ExpandElement>,
        R: Into<ExpandElement>,
    {
        let target: ExpandElement = lhs.into();
        let operand: ExpandElement = rhs.into();
        assign_op_expand(context, target, operand, Operator::Sub)
    }

    // Placeholder `SubAssign` impls. Each type accepts itself plus the listed
    // primitive right-hand sides; the bodies are `unexpanded!()`.
    impl_op_assign!(
        (SubAssign|sub_assign) => {
            UInt | u32,
            I32 | i32;u32,
            I64 | i32;u32,
            F16 | f32;u32,
            BF16 | f32;u32,
            F32 | f32;u32,
            F64 | f32;u32
        }
    );
}

pub mod mul_assign_op {
    use super::*;

    use self::ir::Operator;
    use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64};
    use core::ops::MulAssign;

    /// Expansion of the compound `*=` operator.
    ///
    /// Converts both operands into [`ExpandElement`]s (left side first) and hands them
    /// to `assign_op_expand` with `Operator::Mul`. The element that helper produces is
    /// returned.
    pub fn expand<L, R>(context: &mut CubeContext, lhs: L, rhs: R) -> ExpandElement
    where
        L: Into<ExpandElement>,
        R: Into<ExpandElement>,
    {
        let target: ExpandElement = lhs.into();
        let operand: ExpandElement = rhs.into();
        assign_op_expand(context, target, operand, Operator::Mul)
    }

    // Placeholder `MulAssign` impls. Each type accepts itself plus the listed
    // primitive right-hand sides; the bodies are `unexpanded!()`.
    impl_op_assign!(
        (MulAssign|mul_assign) => {
            UInt | u32,
            I32 | i32;u32,
            I64 | i32;u32,
            F16 | f32;u32,
            BF16 | f32;u32,
            F32 | f32;u32,
            F64 | f32;u32
        }
    );
}

pub mod div_assign_op {
    use super::*;

    use self::ir::Operator;
    use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64};
    use core::ops::DivAssign;

    /// Expansion of the compound `/=` operator.
    ///
    /// Converts both operands into [`ExpandElement`]s (left side first) and hands them
    /// to `assign_op_expand` with `Operator::Div`. The element that helper produces is
    /// returned.
    pub fn expand<L, R>(context: &mut CubeContext, lhs: L, rhs: R) -> ExpandElement
    where
        L: Into<ExpandElement>,
        R: Into<ExpandElement>,
    {
        let target: ExpandElement = lhs.into();
        let operand: ExpandElement = rhs.into();
        assign_op_expand(context, target, operand, Operator::Div)
    }

    // Placeholder `DivAssign` impls. Each type accepts itself plus the listed
    // primitive right-hand sides; the bodies are `unexpanded!()`.
    impl_op_assign!(
        (DivAssign|div_assign) => {
            UInt | u32,
            I32 | i32;u32,
            I64 | i32;u32,
            F16 | f32;u32,
            BF16 | f32;u32,
            F32 | f32;u32,
            F64 | f32;u32
        }
    );
}