burn-cubecl 0.21.0-pre.4

Generic backend that can be compiled just-in-time to any shader language target
Documentation
use burn_backend::Shape;
use cubecl::prelude::SequenceArg;
use cubecl::{
    ir::{UIntKind, VectorSize},
    prelude::*,
    std::{
        FastDivmod, FastDivmodInt,
        tensor::layout::linear::{LinearLayoutLaunch, LinearViewLayoutLaunch},
    },
};

use crate::{CubeRuntime, tensor::CubeTensor};

/// Collects the shape of `tensor` into a launch sequence of [`FastDivmod`] arguments.
///
/// One entry is pushed per dimension, in the same order as the tensor's shape.
pub fn shape_divmod<R: CubeRuntime>(tensor: &CubeTensor<R>) -> SequenceArg<R, FastDivmod<usize>> {
    let mut divisors: SequenceArg<R, FastDivmod<usize>> = SequenceArg::new();
    tensor.meta.shape().iter().for_each(|extent| {
        divisors.push(*extent);
    });
    divisors
}

/// Builds a [`LinearLayoutLaunch`] from the shape and strides of `tensor`.
///
/// `vector_size` is attached to a placeholder element type that is handed to
/// the layout constructor together with cloned shape and strides.
pub fn linear_layout<R: CubeRuntime>(
    tensor: &CubeTensor<R>,
    vector_size: VectorSize,
) -> LinearLayoutLaunch<R> {
    let meta = &tensor.meta;
    // The element kind is a placeholder: only the vector size is relevant
    // here, not the size of the type.
    let vectorized_ty = Type::new(UIntKind::U32.into()).with_vector_size(vector_size);

    LinearLayoutLaunch::from_shape_strides(
        meta.shape().clone(),
        meta.strides().clone(),
        vectorized_ty,
        LinearViewLayoutLaunch::new(),
    )
}

/// Replaces dimension `dim` of `tensor` with the dimensions listed in `shape`.
///
/// The original dimension is removed from the metadata and the new sizes are
/// inserted at the same position. The last entry of `shape` receives the
/// stride of the removed dimension; each earlier entry receives the stride of
/// its successor multiplied by the successor's size.
pub fn split_dim<R: CubeRuntime>(
    mut tensor: CubeTensor<R>,
    dim: usize,
    shape: &[usize],
) -> CubeTensor<R> {
    let base_stride = tensor.meta.strides()[dim];
    tensor.meta.remove(dim);

    // Walk the new sizes innermost-first. Inserting at `dim` every time shifts
    // the previously inserted entries right, so the final order matches `shape`.
    let mut running_stride = base_stride;
    for &extent in shape.iter().rev() {
        tensor.meta.insert(dim, extent, running_stride);
        running_stride = running_stride * extent;
    }

    tensor
}

/// Computes the shape obtained by broadcasting all `tensors` together.
///
/// For every dimension the result holds the largest size found among the
/// inputs. In debug builds it is additionally checked that all tensors have
/// the same rank and that every size is either that maximum or `1`.
///
/// # Panics
///
/// Panics if `tensors` is empty, since the rank is read from the first tensor.
pub fn broadcast_shape<R: CubeRuntime>(tensors: &[&CubeTensor<R>]) -> Shape {
    let rank = tensors[0].meta.num_dims();
    debug_assert!(
        !tensors.iter().any(|tensor| tensor.meta.num_dims() != rank),
        "Broadcast tensors must have the same rank"
    );

    Shape::from((0..rank).map(|axis| {
        let widest = tensors
            .iter()
            .map(|tensor| tensor.meta.shape()[axis])
            .max()
            .unwrap_or(1);
        debug_assert!(
            tensors.iter().all(|tensor| {
                let extent = tensor.meta.shape()[axis];
                extent == widest || extent == 1
            }),
            "Broadcast dims must be size 1"
        );
        widest
    }))
}

/// Returns the strides of `tensor` adjusted for broadcasting against `reference`.
///
/// When both shapes are equal the strides are copied unchanged. Otherwise the
/// stride of every dimension whose size differs from the reference's size is
/// replaced by `0`, while matching dimensions keep their stride.
pub fn broadcast_strides<R: CubeRuntime>(
    reference: &CubeTensor<R>,
    tensor: &CubeTensor<R>,
) -> SequenceArg<R, usize> {
    let own_strides = tensor.meta.strides();

    // Identical shapes: nothing is broadcast, so the strides are used as-is.
    if reference.meta.shape() == tensor.meta.shape() {
        return own_strides.iter().copied().collect();
    }

    own_strides
        .iter()
        .zip(tensor.meta.shape().iter())
        .zip(reference.meta.shape().iter())
        .map(|((stride, extent), ref_extent)| {
            if *extent == *ref_extent {
                *stride
            } else {
                0
            }
        })
        .collect()
}

/// Decomposes the linear position `pos` into one offset per entry of `shape`.
///
/// The entries of `shape` are processed from the last to the first. At each
/// step the current value is split with `div_mod`; one part is recorded as the
/// offset for that entry and the other part is carried over to the next entry.
///
/// Returns `(rest, offsets)`, where `rest` is the value carried out of the
/// final step and `offsets` holds the recorded offsets, reversed so that they
/// line up with the order of `shape`.
#[cube]
pub(crate) fn decompose_linear<I: FastDivmodInt>(
    pos: I,
    shape: &Sequence<FastDivmod<I>>,
) -> (I, Sequence<I>) {
    // The number of entries is resolved at compile time so the loop can be unrolled.
    let rank = comptime![shape.len()];
    let mut offs = pos;
    let mut out = Sequence::new();

    #[unroll]
    for i in 0..rank {
        // Visit the entries back to front: last entry first.
        let dim = comptime![rank - i - 1];
        // NOTE(review): presumably `div_mod` yields (quotient, remainder), so `rem`
        // is what is still left to decompose and `offs_local` is the offset along
        // `dim` — confirm against the `FastDivmod` documentation.
        let (rem, offs_local) = shape.index(dim).div_mod(offs);
        out.push(offs_local);
        offs = rem;
    }

    // Offsets were pushed last-entry-first; reverse them to match `shape`.
    (offs, out.rev())
}

/// Crate-internal abstraction over values that can report the [`AddressType`]
/// they require.
///
/// It gives the `address_type!` macro a single call path for every argument
/// type that implements it.
pub(crate) trait RequiredAddrType {
    /// Returns the address type required by `self`.
    fn required_address_type(&self) -> AddressType;
}

/// A tensor reports its own required address type.
impl<R: CubeRuntime> RequiredAddrType for CubeTensor<R> {
    fn required_address_type(&self) -> AddressType {
        // NOTE(review): this presumably resolves to an inherent
        // `CubeTensor::required_address_type` defined elsewhere (inherent methods
        // take precedence over trait methods in method-call syntax). If no such
        // inherent method exists, this call would recurse into itself — verify.
        self.required_address_type()
    }
}
impl<R: CubeRuntime> RequiredAddrType for Option<CubeTensor<R>> {
    fn required_address_type(&self) -> AddressType {
        self.as_ref()
            .map(|it| it.required_address_type())
            .unwrap_or_default()
    }
}

/// Evaluates to the largest [`AddressType`] required by any of the arguments.
///
/// Takes a comma-separated list (no trailing comma) of arguments, each a
/// single token tree such as an identifier or a parenthesized expression.
/// Every argument must implement `RequiredAddrType`; each one is borrowed,
/// not moved. If the list yields no value, the default address type is used.
macro_rules! address_type {
    ($($tensor: tt),*) => {
        // Collect every argument's requirement into an array and keep the maximum.
        [$($crate::kernel::utils::RequiredAddrType::required_address_type(&$tensor)),*]
        .into_iter()
        .max()
        .unwrap_or_default()
    };
}
// Re-export so the macro can be imported by path from within the crate.
pub(crate) use address_type;