cubecl-std 0.10.0-pre.3

CubeCL Standard Library.
Documentation
use cubecl::prelude::*;
use cubecl_core::{self as cubecl, unexpanded};
use variadics_please::all_tuples;

use crate::tensor::{
    launch::{BufferArg, ViewLayoutLaunchArg},
    layout::*,
};

/// Coordinates that can be converted to a dynamic sequence of signed coordinates.
/// Can be used to convert any set of coordinates to a comptime-sized sequence for use with TMA.
///
/// This file implements it for coordinate tuples of 2 to 12 elements (via `impl_tuple!`) and
/// for `Sequence<i32>` / `Sequence<u32>`.
#[cube]
pub trait IntoDyn: Coordinates + LaunchArg {
    /// Consumes the coordinates and returns them as a `Sequence<i32>`.
    ///
    /// The default body is only the `unexpanded!()` placeholder. The tuple impls keep this
    /// default and provide the actual conversion through
    /// `IntoDynExpand::__expand_into_dyn_method`.
    fn into_dyn(self) -> Sequence<i32> {
        unexpanded!()
    }
}

// Implements `IntoDyn` and `IntoDynExpand` for one tuple arity.
//
// Input is a comma-separated list of `(TypeIdent, value_ident)` pairs, one per tuple element.
// Every element is cast to `i32` and pushed, in tuple order, onto a fresh sequence.
macro_rules! impl_tuple {
    ($(($T: ident, $t: ident)),*) => {
        // The non-expanded side keeps the trait's default `into_dyn` body.
        impl<$($T),*> IntoDyn for ($($T),*)
        where
            $($T: Coordinates + CubePrimitive + LaunchArg),*
        {}

        impl<$($T),*> IntoDynExpand for ($(NativeExpand<$T>),*)
        where
            $($T: Coordinates + CubePrimitive + LaunchArg),*
        {
            fn __expand_into_dyn_method(self, scope: &mut Scope) -> SequenceExpand<i32> {
                let mut dims = Sequence::__expand_new(scope);
                let ($($t),*) = self;
                // Cast every element first (left to right), then push them in the same order.
                $(let $t = i32::__expand_cast_from(scope, $t);)*
                $(dims.__expand_push_method(scope, $t);)*
                dims
            }
        }
    };
}

// Instantiate the impls for tuples of 2 through 12 elements.
all_tuples!(impl_tuple, 2, 12, T, t);

// `Sequence<i32>` is already the target type of `into_dyn`, so it is returned unchanged.
#[cube]
impl IntoDyn for Sequence<i32> {
    fn into_dyn(self) -> Sequence<i32> {
        self
    }
}

// Unsigned sequences are rebuilt element by element, casting each entry with `i32::cast_from`
// and preserving the original order.
#[cube]
impl IntoDyn for Sequence<u32> {
    fn into_dyn(self) -> Sequence<i32> {
        let mut signed = Sequence::<i32>::new();
        for coord in self {
            signed.push(i32::cast_from(coord));
        }
        signed
    }
}

/// Layout adapter around an inner layout `L` whose source coordinates implement [`IntoDyn`].
///
/// Its `Layout` impl keeps `L::Coordinates` and exposes `Sequence<i32>` as source coordinates,
/// obtained by calling `into_dyn()` on the source position produced by the inner layout.
#[derive(CubeType)]
pub struct IntoDynLayout<L: Layout<SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg> {
    // The wrapped layout; every operation is delegated to it.
    layout: L,
}

// Launch support for `IntoDynLayout`: the wrapper adds no launch state of its own, so it reuses
// the inner layout's runtime/compilation argument types and forwards every call to `L`.
impl<L> ViewLayoutLaunchArg for IntoDynLayout<L>
where
    L: Layout<SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg,
{
    type RuntimeArg<R: Runtime> = <L as ViewLayoutLaunchArg>::RuntimeArg<R>;
    type CompilationArg = <L as ViewLayoutLaunchArg>::CompilationArg;

    // Registers the inner layout's argument and returns its compilation argument unchanged.
    fn register<R: Runtime, B: BufferArg>(
        arg: Self::RuntimeArg<R>,
        buffer: &B,
        ty: Type,
        launcher: &mut KernelLauncher<R>,
    ) -> Self::CompilationArg {
        <L as ViewLayoutLaunchArg>::register::<R, B>(arg, buffer, ty, launcher)
    }

    // Expands the inner layout and wraps the result in the adapter's expand type.
    fn expand(
        arg: &Self::CompilationArg,
        ty: Type,
        builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType {
        let layout = <L as ViewLayoutLaunchArg>::expand(arg, ty, builder);
        IntoDynLayoutExpand { layout }
    }

    // Same as `expand`, but goes through the inner layout's `expand_output`.
    fn expand_output(
        arg: &Self::CompilationArg,
        ty: Type,
        builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType {
        let layout = <L as ViewLayoutLaunchArg>::expand_output(arg, ty, builder);
        IntoDynLayoutExpand { layout }
    }
}

/// Layout adapter around an inner layout `L` whose source coordinates are a pair `(P, O)` with
/// both halves implementing [`IntoDyn`].
///
/// Its `Layout` impl keeps `L::Coordinates` and exposes `(Sequence<i32>, Sequence<i32>)` as
/// source coordinates, calling `into_dyn()` on each half of the inner source position.
#[derive(CubeType)]
pub struct IntoDyn2Layout<
    L: Layout<SourceCoordinates = (P, O)> + ViewLayoutLaunchArg,
    P: IntoDyn,
    O: IntoDyn,
> {
    // The wrapped layout; every operation is delegated to it.
    layout: L,
}

// Launch support for `IntoDyn2Layout`: the wrapper adds no launch state of its own, so it reuses
// the inner layout's runtime/compilation argument types and forwards every call to `L`.
impl<L, P, O> ViewLayoutLaunchArg for IntoDyn2Layout<L, P, O>
where
    L: Layout<SourceCoordinates = (P, O)> + ViewLayoutLaunchArg,
    P: IntoDyn,
    O: IntoDyn,
{
    type RuntimeArg<R: Runtime> = <L as ViewLayoutLaunchArg>::RuntimeArg<R>;
    type CompilationArg = <L as ViewLayoutLaunchArg>::CompilationArg;

    // Registers the inner layout's argument and returns its compilation argument unchanged.
    fn register<R: Runtime, B: BufferArg>(
        arg: Self::RuntimeArg<R>,
        buffer: &B,
        ty: Type,
        launcher: &mut KernelLauncher<R>,
    ) -> Self::CompilationArg {
        <L as ViewLayoutLaunchArg>::register::<R, B>(arg, buffer, ty, launcher)
    }

    // Expands the inner layout and wraps the result in the adapter's expand type.
    fn expand(
        arg: &Self::CompilationArg,
        ty: Type,
        builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType {
        let layout = <L as ViewLayoutLaunchArg>::expand(arg, ty, builder);
        IntoDyn2LayoutExpand { layout }
    }

    // Same as `expand`, but goes through the inner layout's `expand_output`.
    fn expand_output(
        arg: &Self::CompilationArg,
        ty: Type,
        builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType {
        let layout = <L as ViewLayoutLaunchArg>::expand_output(arg, ty, builder);
        IntoDyn2LayoutExpand { layout }
    }
}

impl<L> IntoDynLayout<L>
where
    L: Layout<SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg,
{
    /// Wraps `layout` so that its source coordinates are exposed as a `Sequence<i32>`.
    pub fn new(layout: L) -> Self {
        Self { layout }
    }
}

// The bounds mirror the struct definition exactly. `O` is a coordinate type, so it only needs
// `IntoDyn`; requiring `ViewLayoutLaunchArg` on it would make the constructor unavailable for
// instantiations that the struct and its launch impl otherwise accept.
impl<L: Layout<SourceCoordinates = (P, O)> + ViewLayoutLaunchArg, P: IntoDyn, O: IntoDyn>
    IntoDyn2Layout<L, P, O>
{
    /// Wraps `layout` so that both halves of its `(P, O)` source coordinates are exposed as
    /// `Sequence<i32>`.
    pub fn new(layout: L) -> Self {
        IntoDyn2Layout { layout }
    }
}

// `Layout` adapter: same logical coordinates as the inner layout, but source positions are
// converted to `Sequence<i32>` through `IntoDyn::into_dyn`.
#[cube]
impl<L: Layout<SourceCoordinates: IntoDyn> + ViewLayoutLaunchArg> Layout for IntoDynLayout<L> {
    type Coordinates = L::Coordinates;
    type SourceCoordinates = Sequence<i32>;

    // Maps `pos` through the inner layout, then converts the result to the dynamic form.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let inner_pos = self.layout.to_source_pos(pos);
        inner_pos.into_dyn()
    }

    // Bounds checking is entirely the inner layout's decision.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        self.layout.is_in_bounds(pos)
    }

    // Like `to_source_pos`, additionally forwarding the inner layout's in-bounds flag untouched.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        let (inner_pos, valid) = self.layout.to_source_pos_checked(pos);
        (inner_pos.into_dyn(), valid)
    }

    // The shape is expressed in logical coordinates, so it is passed through as is.
    fn shape(&self) -> Self::Coordinates {
        self.layout.shape()
    }
}

// `Layout` adapter for layouts with paired source coordinates `(P, O)`: same logical coordinates
// as the inner layout, with each half of the source position converted to `Sequence<i32>`.
//
// The bounds mirror the struct definition: `O` is a coordinate type and only needs `IntoDyn`.
#[cube]
impl<L: Layout<SourceCoordinates = (P, O)> + ViewLayoutLaunchArg, P: IntoDyn, O: IntoDyn> Layout
    for IntoDyn2Layout<L, P, O>
{
    type Coordinates = L::Coordinates;
    type SourceCoordinates = (Sequence<i32>, Sequence<i32>);

    // Maps `pos` through the inner layout, then converts both halves to the dynamic form.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let pos = self.layout.to_source_pos(pos);
        (pos.0.into_dyn(), pos.1.into_dyn())
    }

    // Bounds checking is entirely the inner layout's decision.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        self.layout.is_in_bounds(pos)
    }

    // Like `to_source_pos`, additionally forwarding the inner layout's in-bounds flag untouched.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        let (pos, in_bounds) = self.layout.to_source_pos_checked(pos);
        ((pos.0.into_dyn(), pos.1.into_dyn()), in_bounds)
    }

    // The shape is expressed in logical coordinates, so it is passed through as is.
    fn shape(&self) -> Self::Coordinates {
        self.layout.shape()
    }
}