burn-cubecl 0.21.0-pre.3

Generic backend that can be compiled just-in-time to any shader language target
Documentation
use crate::{
    CubeRuntime,
    kernel::utils::{address_type, broadcast_shape},
    ops::numeric::empty_device_dtype,
    tensor::CubeTensor,
};
use cubecl::std::tensor::layout::linear::LinearView;
use cubecl::{calculate_cube_count_elemwise, prelude::*};

/// Cross-product kernel over consecutive triples of a linear view.
///
/// Unit `ABSOLUTE_POS` owns the 3-component vector stored at logical positions
/// `3 * ABSOLUTE_POS .. 3 * ABSOLUTE_POS + 3`. It loads that triple from `lhs`
/// and `rhs`, then writes `lhs × rhs` to the same positions of `output`.
/// Units whose first component falls outside `output` exit immediately.
#[cube(launch_unchecked, address_type = "dynamic")]
fn cross_kernel<E: Float>(
    lhs: &LinearView<E>,
    rhs: &LinearView<E>,
    output: &mut LinearView<E, ReadWrite>,
    #[define(E)] _dtype: StorageType,
) {
    let first = ABSOLUTE_POS * 3;

    // Surplus units (beyond the last triple) have nothing to do.
    if !output.is_in_bounds(first) {
        terminate!();
    }

    let second = first + 1;
    let third = first + 2;

    // Load both operands completely before any store.
    let u_x = lhs[first];
    let u_y = lhs[second];
    let u_z = lhs[third];
    let v_x = rhs[first];
    let v_y = rhs[second];
    let v_z = rhs[third];

    // u × v = (u_y v_z - u_z v_y, u_z v_x - u_x v_z, u_x v_y - u_y v_x)
    output[first] = u_y * v_z - u_z * v_y;
    output[second] = u_z * v_x - u_x * v_z;
    output[third] = u_x * v_y - u_y * v_x;
}

/// Computes the cross product of `lhs` and `rhs` along dimension `dim`.
///
/// The two inputs are broadcast against each other (`broadcast_shape`). The
/// result is written to a newly allocated tensor on `lhs`'s device with `lhs`'s
/// dtype. One kernel unit is launched per 3-component vector.
///
/// # Panics
///
/// - if `dim` is not a valid dimension of both `lhs` and `rhs`;
/// - if dimension `dim` does not have size 3 in both inputs;
/// - if `dim` is not the last dimension (not yet implemented for this backend).
pub(crate) fn cross<R: CubeRuntime>(
    lhs: CubeTensor<R>,
    rhs: CubeTensor<R>,
    dim: usize,
) -> CubeTensor<R> {
    let ndims = lhs.meta.num_dims();
    let rhs_ndims = rhs.meta.num_dims();

    // Check the range first so that rank-0 tensors and bad `dim` values get a
    // clear message instead of an index-out-of-bounds panic (and so that
    // `ndims - 1` below cannot underflow).
    assert!(
        dim < ndims && dim < rhs_ndims,
        "Cross product dimension {} is out of range for tensors of rank {} and {}",
        dim,
        ndims,
        rhs_ndims
    );

    // Validate that the cross dimension has size 3
    let lhs_size = lhs.meta.shape()[dim];
    let rhs_size = rhs.meta.shape()[dim];
    if lhs_size != 3 || rhs_size != 3 {
        panic!(
            "Cross product requires dimension {} to have size 3, but got {} and {}",
            dim, lhs_size, rhs_size
        );
    }

    // For now, only support cross on the last dimension
    if dim != ndims - 1 {
        unimplemented!(
            "Cross product on non-last dimension not yet implemented for CubeCL backend"
        );
    }

    let output_shape = broadcast_shape(&[&lhs, &rhs]);
    let dtype = lhs.dtype;

    // Number of 3-component vectors to process. Computed before the shape is
    // moved into the allocation, so no clone is needed.
    let num_vectors = output_shape.num_elements() / 3;

    let output = empty_device_dtype(lhs.client.clone(), lhs.device.clone(), output_shape, dtype);

    // Empty result: skip the launch. A zero-sized grid is rejected by some
    // backends, and there is nothing to compute anyway.
    if num_vectors == 0 {
        return output;
    }

    let cube_dim = CubeDim::new(&lhs.client, num_vectors);
    let cube_count = calculate_cube_count_elemwise(&lhs.client, num_vectors, cube_dim);

    // SAFETY: `dim` is the last dimension and has size 3, so the output's logical
    // element count is a multiple of 3. The kernel guards on
    // `output.is_in_bounds(base)`, and any in-bounds `base = 3 * k` then also has
    // `base + 1` and `base + 2` in bounds. `lhs`/`rhs` are viewed "like" `output`;
    // this assumes `into_linear_view_like` maps them onto the output's
    // (broadcast) logical shape.
    unsafe {
        cross_kernel::launch_unchecked(
            &output.client,
            cube_count,
            cube_dim,
            address_type!(lhs, rhs, output),
            lhs.into_linear_view_like(&output),
            rhs.into_linear_view_like(&output),
            output.clone().into_linear_view(),
            dtype.into(),
        );
    }

    output
}