cubecl-core 0.10.0-pre.3

CubeCL core crate
Documentation
use super::{ReadOnly, ReadWrite, Slice, SliceExpand, SliceOriginExpand, SliceVisibility};
use crate as cubecl;
use crate::{ir::Scope, prelude::*, unexpanded};
use cubecl_common::tf32;
use cubecl_ir::ManagedVariable;

/// Returns `true` when one of `C` / `T` has the `f32` storage type and the other has the
/// `tf32` storage type (in either order); `false` otherwise.
///
/// The storage types are resolved through `scope` in the order `C`, `T`, `f32`, `tf32`.
pub(crate) fn is_tf32<C: CubePrimitive, T: CubePrimitive>(scope: &mut Scope) -> bool {
    let lhs = C::as_type(scope).storage_type();
    let rhs = T::as_type(scope).storage_type();
    let full = f32::as_type(scope).storage_type();
    let reduced = tf32::as_type(scope).storage_type();

    let f32_then_tf32 = lhs == full && rhs == reduced;
    let tf32_then_f32 = lhs == reduced && rhs == full;

    f32_then_tf32 || tf32_then_f32
}

impl<E: CubePrimitive> SliceOperator<E> for SharedMemory<E> {}
impl<E: CubePrimitive> SliceOperatorExpand<E> for NativeExpand<SharedMemory<E>> {
    /// Expands `shared.slice(start, end)` into a read-only view of this shared memory,
    /// built by `Slice::__expand_new` from `start` and `end`.
    fn __expand_slice_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadOnly> {
        let origin = SliceOriginExpand::SharedMemory(self.clone());
        Slice::__expand_new(scope, origin, start, end)
    }

    /// Expands `shared.to_slice()` into a read-only view from `0` to the length reported by
    /// `expand_length_native`.
    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
        // The length is emitted into `scope` before the view itself is constructed.
        let full_len: NativeExpand<usize> =
            ManagedVariable::Plain(expand_length_native(scope, *self.expand)).into();
        <Self as SliceOperatorExpand<E>>::__expand_slice_method(
            self,
            scope,
            0usize.into(),
            full_len,
        )
    }
}

impl<E: CubePrimitive> SliceMutOperator<E> for SharedMemory<E> {}
impl<E: CubePrimitive> SliceMutOperatorExpand<E> for NativeExpand<SharedMemory<E>> {
    /// Expands `shared.slice_mut(start, end)` into a read-write view of this shared memory,
    /// built by `Slice::__expand_new` from `start` and `end`.
    fn __expand_slice_mut_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadWrite> {
        let origin = SliceOriginExpand::SharedMemory(self.clone());
        Slice::__expand_new(scope, origin, start, end)
    }

    /// Expands `shared.to_slice_mut()` into a read-write view from `0` to the length reported
    /// by `expand_length_native`.
    fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
        // The length is emitted into `scope` before the view itself is constructed.
        let full_len: NativeExpand<usize> =
            ManagedVariable::Plain(expand_length_native(scope, *self.expand)).into();
        <Self as SliceMutOperatorExpand<E>>::__expand_slice_mut_method(
            self,
            scope,
            0usize.into(),
            full_len,
        )
    }
}

impl<E: CubePrimitive> SliceOperator<E> for Tensor<E> {}
impl<E: CubePrimitive> SliceOperatorExpand<E> for NativeExpand<Tensor<E>> {
    /// Expands `tensor.slice(start, end)` into a read-only view of this tensor, built by
    /// `Slice::__expand_new` from `start` and `end`.
    fn __expand_slice_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadOnly> {
        let origin = SliceOriginExpand::Tensor(self.clone());
        Slice::__expand_new(scope, origin, start, end)
    }

    /// Expands `tensor.to_slice()` into a read-only view from `0` to `tensor.len()`.
    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
        // The length is expanded into `scope` before the view itself is constructed.
        let whole = self.clone().__expand_len_method(scope);
        <Self as SliceOperatorExpand<E>>::__expand_slice_method(
            self,
            scope,
            0usize.into(),
            whole,
        )
    }
}

impl<E: CubePrimitive> SliceMutOperator<E> for Tensor<E> {}
impl<E: CubePrimitive> SliceMutOperatorExpand<E> for NativeExpand<Tensor<E>> {
    /// Expands `tensor.slice_mut(start, end)` into a read-write view of this tensor, built by
    /// `Slice::__expand_new` from `start` and `end`.
    fn __expand_slice_mut_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadWrite> {
        let origin = SliceOriginExpand::Tensor(self.clone());
        Slice::__expand_new(scope, origin, start, end)
    }

    /// Expands `tensor.to_slice_mut()` into a read-write view from `0` to `tensor.len()`.
    fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
        // The length is expanded into `scope` before the view itself is constructed.
        let whole = self.clone().__expand_len_method(scope);
        <Self as SliceMutOperatorExpand<E>>::__expand_slice_mut_method(
            self,
            scope,
            0usize.into(),
            whole,
        )
    }
}

impl<E: CubePrimitive> SliceOperator<E> for Array<E> {}
impl<E: CubePrimitive> SliceOperatorExpand<E> for NativeExpand<Array<E>> {
    /// Expands `array.slice(start, end)` into a read-only view of this array, built by
    /// `Slice::__expand_new` from `start` and `end`.
    fn __expand_slice_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadOnly> {
        let origin = SliceOriginExpand::Array(self.clone());
        Slice::__expand_new(scope, origin, start, end)
    }

    /// Expands `array.to_slice()` into a read-only view from `0` to `array.len()`.
    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
        // The length is expanded into `scope` before the view itself is constructed.
        let whole = self.clone().__expand_len_method(scope);
        <Self as SliceOperatorExpand<E>>::__expand_slice_method(
            self,
            scope,
            0usize.into(),
            whole,
        )
    }
}

impl<E: CubePrimitive> SliceMutOperator<E> for Array<E> {}
impl<E: CubePrimitive> SliceMutOperatorExpand<E> for NativeExpand<Array<E>> {
    /// Expands `array.slice_mut(start, end)` into a read-write view of this array, built by
    /// `Slice::__expand_new` from `start` and `end`.
    fn __expand_slice_mut_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadWrite> {
        let origin = SliceOriginExpand::Array(self.clone());
        Slice::__expand_new(scope, origin, start, end)
    }

    /// Expands `array.to_slice_mut()` into a read-write view from `0` to `array.len()`.
    fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
        // The length is expanded into `scope` before the view itself is constructed.
        let whole = self.clone().__expand_len_method(scope);
        <Self as SliceMutOperatorExpand<E>>::__expand_slice_mut_method(
            self,
            scope,
            0usize.into(),
            whole,
        )
    }
}

impl<E: CubePrimitive, IO: SliceVisibility> SliceOperator<E> for Slice<E, IO> {}
impl<E: CubePrimitive, IO: SliceVisibility> SliceOperatorExpand<E> for SliceExpand<E, IO> {
    fn __expand_slice_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadOnly> {
        let length = crate::frontend::sub::expand(scope, end, start.clone());
        let offset = crate::frontend::add::expand(scope, start, self.offset.clone());

        SliceExpand {
            origin: self.origin.clone(),
            io: core::marker::PhantomData,
            offset,
            length,
            vector_size: self.vector_size,
        }
    }

    fn __expand_to_slice_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
        SliceExpand {
            origin: self.origin.clone(),
            io: core::marker::PhantomData,
            offset: self.offset.clone(),
            length: self.length.clone(),
            vector_size: self.vector_size,
        }
    }
}

impl<E: CubePrimitive> SliceMutOperator<E> for Slice<E, ReadWrite> {}
impl<E: CubePrimitive> SliceMutOperatorExpand<E> for SliceExpand<E, ReadWrite> {
    fn __expand_slice_mut_method(
        &self,
        scope: &mut Scope,
        start: NativeExpand<usize>,
        end: NativeExpand<usize>,
    ) -> SliceExpand<E, ReadWrite> {
        let length = crate::frontend::sub::expand(scope, end, start.clone());
        let offset = crate::frontend::add::expand(scope, start, self.offset.clone());

        SliceExpand {
            origin: self.origin.clone(),
            io: core::marker::PhantomData,
            offset,
            length,
            vector_size: self.vector_size,
        }
    }

    fn __expand_to_slice_mut_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
        SliceExpand {
            origin: self.origin.clone(),
            io: core::marker::PhantomData,
            offset: self.offset.clone(),
            length: self.length.clone(),
            vector_size: self.vector_size,
        }
    }
}

/// Slicing operators that produce read-only [`Slice`] views.
///
/// The method bodies below are placeholders: calling them directly reaches `unexpanded!()`.
/// The working behavior lives in the `__expand_*` methods of the `SliceOperatorExpand`
/// implementations in this file. Presumably those are what `#[cube]` kernels call; confirm
/// against the macro.
#[cube(self_type = "ref")]
pub trait SliceOperator<E: CubePrimitive> {
    /// Return a read-only view of all elements comprised between the `start` and `end` indices.
    /// In `checked` mode, if the `end` index is out-of-bound, it is replaced by
    /// the length of `self`.
    // Parameters are unused because the body is only an `unexpanded!()` placeholder.
    #[allow(unused_variables)]
    fn slice(&self, start: usize, end: usize) -> Slice<E, ReadOnly> {
        unexpanded!()
    }

    /// Reinterpret the current type as a read-only slice.
    #[allow(unused_variables)]
    fn to_slice(&self) -> Slice<E, ReadOnly> {
        unexpanded!()
    }
}

/// Slicing operators that produce read-write [`Slice`] views.
///
/// The method bodies below are placeholders: calling them directly reaches `unexpanded!()`.
/// The working behavior lives in the `__expand_*` methods of the `SliceMutOperatorExpand`
/// implementations in this file. Presumably those are what `#[cube]` kernels call; confirm
/// against the macro.
#[cube(self_type = "ref")]
pub trait SliceMutOperator<E: CubePrimitive> {
    /// Return a read-write view of all elements comprised between the `start` and `end` indices.
    /// In `checked` mode, if the `end` index is out-of-bound, it is replaced by
    /// the length of `self`.
    // Parameters are unused because the body is only an `unexpanded!()` placeholder.
    #[allow(unused_variables)]
    fn slice_mut(&mut self, start: usize, end: usize) -> Slice<E, ReadWrite> {
        unexpanded!()
    }

    /// Reinterpret the current type as a read-write slice.
    #[allow(unused_variables)]
    fn to_slice_mut(&mut self) -> Slice<E, ReadWrite> {
        unexpanded!()
    }
}

// `SliceOperator` for `&L`: available whenever `L` has it and the reference expands to the
// same type as `L` itself.
impl<'a, T, L> SliceOperator<T> for &'a L
where
    T: CubePrimitive,
    L: SliceOperator<T>,
    &'a L: CubeType<ExpandType = L::ExpandType>,
{
}

// `SliceOperator` for `&mut L`, under the same expand-type condition.
impl<'a, T, L> SliceOperator<T> for &'a mut L
where
    T: CubePrimitive,
    L: SliceOperator<T>,
    &'a mut L: CubeType<ExpandType = L::ExpandType>,
{
}

// `SliceMutOperator` for `&L`, under the same expand-type condition.
impl<'a, T, L> SliceMutOperator<T> for &'a L
where
    T: CubePrimitive,
    L: SliceMutOperator<T>,
    &'a L: CubeType<ExpandType = L::ExpandType>,
{
}

// `SliceMutOperator` for `&mut L`, under the same expand-type condition.
impl<'a, T, L> SliceMutOperator<T> for &'a mut L
where
    T: CubePrimitive,
    L: SliceMutOperator<T>,
    &'a mut L: CubeType<ExpandType = L::ExpandType>,
{
}