burn-cubecl 0.21.0-pre.2

Generic backend that can be compiled just-in-time to any shader language target
Documentation
use burn_backend::Shape;
use cubecl::{
    ir::LineSize,
    prelude::*,
    std::{
        FastDivmod, FastDivmodArgs, FastDivmodInt,
        tensor::layout::linear::{LinearLayoutArgs, LinearViewLaunch},
    },
};
use cubecl::{prelude::SequenceArg, std::tensor::layout::linear::LinearLayout};

use crate::{CubeRuntime, tensor::CubeTensor};

/// Builds a sequence of [`FastDivmod`] launch arguments, one per dimension of `tensor`.
///
/// Entries are pushed in the same order as the tensor's shape, so entry `i` is created from
/// the extent of dimension `i`, using the tensor's own compute client.
pub fn shape_divmod<'a, R: CubeRuntime>(
    tensor: &CubeTensor<R>,
) -> SequenceArg<'a, R, FastDivmod<usize>> {
    let client = &tensor.client;
    let mut divmods = SequenceArg::new();
    for extent in tensor.meta.shape().iter() {
        divmods.push(FastDivmodArgs::<usize>::new(client, *extent));
    }
    divmods
}

/// Creates linear-layout launch arguments from `tensor`'s own shape and strides.
///
/// `line_size` is forwarded unchanged to [`LinearLayoutArgs::from_shape_strides`].
pub fn linear_layout<'a, R: CubeRuntime>(
    tensor: &'a CubeTensor<R>,
    line_size: LineSize,
) -> LinearLayoutArgs<'a, R> {
    let meta = &tensor.meta;
    LinearLayoutArgs::from_shape_strides(&tensor.client, meta.shape(), meta.strides(), line_size)
}

/// Creates linear-layout launch arguments for `tensor` relative to a `reference` tensor.
///
/// The client, shape and strides are taken from `tensor`; only the shape is taken from
/// `reference`. `line_size` is forwarded unchanged.
pub fn linear_layout_ref<'a, R: CubeRuntime>(
    tensor: &'a CubeTensor<R>,
    reference: &'a CubeTensor<R>,
    line_size: LineSize,
) -> LinearLayoutArgs<'a, R> {
    let shape = tensor.meta.shape();
    let reference_shape = reference.meta.shape();
    let strides = tensor.meta.strides();

    LinearLayoutArgs::from_shape_strides_with_reference(
        &tensor.client,
        shape,
        reference_shape,
        strides,
        line_size,
    )
}

/// Builds a linear view launch argument over `tensor`'s buffer.
///
/// The view pairs an array argument created from the tensor's handle with the layout
/// returned by [`linear_layout`] for the same tensor and `line_size`.
pub fn linear_view<'a, R: CubeRuntime>(
    tensor: &'a CubeTensor<R>,
    line_size: LineSize,
) -> LinearViewLaunch<'a, R> {
    let count = tensor.meta.num_elements();
    let layout = linear_layout(tensor, line_size);
    let elem_size = tensor.elem_size();

    // SAFETY: the element count and element size both come from the tensor that owns the
    // handle. NOTE(review): this assumes the tensor's metadata is consistent with the
    // allocation behind `tensor.handle` — confirm against `ArrayArg`'s safety contract.
    let buffer =
        unsafe { ArrayArg::from_raw_parts_and_size(&tensor.handle, count, line_size, elem_size) };

    LinearViewLaunch::new::<LinearLayout>(buffer, layout)
}

/// Builds a linear view launch argument over `tensor`'s buffer, laid out against `reference`.
///
/// Identical to a plain linear view except that the layout comes from
/// [`linear_layout_ref`], which additionally uses the shape of `reference`.
pub fn linear_view_ref<'a, R: CubeRuntime>(
    tensor: &'a CubeTensor<R>,
    reference: &'a CubeTensor<R>,
    line_size: LineSize,
) -> LinearViewLaunch<'a, R> {
    let count = tensor.meta.num_elements();
    let layout = linear_layout_ref(tensor, reference, line_size);
    let elem_size = tensor.elem_size();

    // SAFETY: the element count and element size both come from the tensor that owns the
    // handle. NOTE(review): this assumes the tensor's metadata is consistent with the
    // allocation behind `tensor.handle` — confirm against `ArrayArg`'s safety contract.
    let buffer =
        unsafe { ArrayArg::from_raw_parts_and_size(&tensor.handle, count, line_size, elem_size) };

    LinearViewLaunch::new::<LinearLayout>(buffer, layout)
}

/// Builds a linear view launch argument that aliases another argument instead of binding
/// its own buffer.
///
/// The layout is still derived from `tensor` and `line_size`; the buffer is
/// `ArrayArg::Alias` with `input_pos` set to `pos`.
pub fn linear_view_alias<'a, R: CubeRuntime>(
    tensor: &'a CubeTensor<R>,
    line_size: LineSize,
    pos: usize,
) -> LinearViewLaunch<'a, R> {
    LinearViewLaunch::new::<LinearLayout>(
        ArrayArg::Alias { input_pos: pos },
        linear_layout(tensor, line_size),
    )
}

/// Replaces dimension `dim` of `tensor` with the dimensions listed in `shape`.
///
/// The metadata entry at `dim` is removed, then one entry per element of `shape` is
/// inserted at index `dim`, going through `shape` from its last element to its first.
/// The first inserted entry reuses the stride of the removed dimension; each following
/// entry gets the previous stride multiplied by the previously inserted size.
///
/// NOTE(review): nothing here checks that the product of `shape` equals the extent of the
/// removed dimension — presumably that is the caller's responsibility.
///
/// # Panics
///
/// Panics if `dim` is out of bounds for the tensor's strides.
pub fn split_dim<R: CubeRuntime>(
    mut tensor: CubeTensor<R>,
    dim: usize,
    shape: &[usize],
) -> CubeTensor<R> {
    let mut running_stride = tensor.meta.strides()[dim];
    tensor.meta.remove(dim);

    // Walk the new sizes back to front, always inserting at the same index.
    for &extent in shape.iter().rev() {
        tensor.meta.insert(dim, extent, running_stride);
        running_stride *= extent;
    }

    tensor
}

/// Computes the broadcast shape of `tensors`: for every dimension, the largest size found
/// among the tensors.
///
/// In debug builds this also asserts that all tensors share the rank of the first tensor,
/// and that each tensor's size along a dimension is either that maximum or `1`.
///
/// # Panics
///
/// Panics if `tensors` is empty, since the rank is read from the first tensor.
pub fn broadcast_shape<R: CubeRuntime>(tensors: &[&CubeTensor<R>]) -> Shape {
    let rank = tensors[0].meta.num_dims();
    debug_assert!(
        !tensors.iter().any(|tensor| tensor.meta.num_dims() != rank),
        "Broadcast tensors must have the same rank"
    );

    let dims = (0..rank).map(|axis| {
        // Sizes of every tensor along the current axis.
        let sizes = || tensors.iter().map(|tensor| tensor.meta.shape()[axis]);

        let largest = sizes().max().unwrap_or(1);
        debug_assert!(
            !sizes().any(|size| size != largest && size != 1),
            "Broadcast dims must be size 1"
        );
        largest
    });

    Shape::from(dims)
}

/// Returns the strides of `tensor` as scalar launch arguments, with stride `0` for every
/// dimension whose size differs from the matching dimension of `reference`.
///
/// When both shapes are equal the strides are passed through unchanged.
pub fn broadcast_strides<'a, R: CubeRuntime>(
    reference: &CubeTensor<R>,
    tensor: &'a CubeTensor<R>,
) -> SequenceArg<'a, R, usize> {
    let strides = tensor.meta.strides().iter();

    // Fast path: identical shapes mean no dimension is broadcast.
    if reference.meta.shape() == tensor.meta.shape() {
        return strides.copied().map(ScalarArg::new).collect();
    }

    strides
        .zip(tensor.meta.shape().iter())
        .zip(reference.meta.shape().iter())
        .map(|((stride, size), ref_size)| {
            // A mismatching dimension is broadcast, so it must not advance in memory.
            if *size == *ref_size { *stride } else { 0 }
        })
        .map(ScalarArg::new)
        .collect()
}

/// Decomposes the linear position `pos` into one value per entry of `shape`.
///
/// The entries of `shape` are visited from last to first. At each step the current offset
/// is passed to that entry's `div_mod`; the second component of the result is recorded for
/// the dimension and the first component becomes the offset for the next step
/// (presumably quotient and remainder respectively — confirm against `FastDivmod`).
///
/// Returns the offset left after all entries have been processed, together with the
/// recorded values in the same order as `shape`.
#[cube]
pub(crate) fn decompose_linear<I: FastDivmodInt>(
    pos: I,
    shape: &Sequence<FastDivmod<I>>,
) -> (I, Sequence<I>) {
    // The number of dimensions is resolved at compile time so the loop can be unrolled.
    let rank = comptime![shape.len()];
    let mut offs = pos;
    let mut out = Sequence::new();

    #[unroll]
    for i in 0..rank {
        // Walk the dimensions back to front.
        let dim = comptime![rank - i - 1];
        let (rem, offs_local) = shape.index(dim).div_mod(offs);
        out.push(offs_local);
        offs = rem;
    }

    // Values were pushed last-dimension-first; reverse them to match the order of `shape`.
    (offs, out.rev())
}

/// Uniform access to the address type required by a tensor-like value.
///
/// Implemented in this file for `CubeTensor<R>` and `Option<CubeTensor<R>>` so that the
/// `address_type!` macro can treat both the same way.
pub(crate) trait RequiredAddrType {
    /// Returns the address type this value requires.
    fn required_address_type(&self) -> AddressType;
}

impl<R: CubeRuntime> RequiredAddrType for CubeTensor<R> {
    fn required_address_type(&self) -> AddressType {
        self.required_address_type()
    }
}
impl<R: CubeRuntime> RequiredAddrType for Option<CubeTensor<R>> {
    /// Returns the tensor's required address type when present, and the default
    /// `AddressType` when the option is `None`.
    fn required_address_type(&self) -> AddressType {
        match self {
            Some(tensor) => tensor.required_address_type(),
            None => Default::default(),
        }
    }
}

/// Expands to the largest `required_address_type()` among the given values.
///
/// Each argument must be a single token tree (for example a variable name) naming a value
/// that implements `RequiredAddrType`; arguments are comma-separated. The results are
/// collected into an array and reduced with `max()`, falling back to the default address
/// type when the array is empty.
macro_rules! address_type {
    ($($tensor: tt),*) => {
        [$($crate::kernel::utils::RequiredAddrType::required_address_type(&$tensor)),*]
        .into_iter()
        .max()
        .unwrap_or_default()
    };
}
pub(crate) use address_type;