use crate::{
CubeRuntime,
kernel::utils::{address_type, broadcast_shape, linear_view, linear_view_ref},
ops::numeric::empty_device_dtype,
tensor::CubeTensor,
};
use cubecl::std::tensor::layout::linear::LinearView;
use cubecl::{calculate_cube_count_elemwise, prelude::*};
/// Cross-product kernel: each unit produces one 3-component output vector.
///
/// Unit `ABSOLUTE_POS` owns the three consecutive positions starting at
/// `ABSOLUTE_POS * 3` of the linear `output` view. It reads `lhs` and `rhs`
/// at those same three positions and writes
/// `(ly*rz - lz*ry, lz*rx - lx*rz, lx*ry - ly*rx)`.
///
/// Assumes the three components of every vector are contiguous in the linear
/// view. The host launcher only launches this kernel when the crossed
/// dimension is the last one and has size 3.
///
/// `_dtype` is the runtime storage type bound to `E` via `#[define(E)]`.
#[cube(launch_unchecked, address_type = "dynamic")]
fn cross_kernel<E: Float>(
    lhs: &LinearView<Line<E>>,
    rhs: &LinearView<Line<E>>,
    output: &mut LinearView<Line<E>, ReadWrite>,
    #[define(E)] _dtype: StorageType,
) {
    // First linear position of the vector handled by this unit.
    let start = ABSOLUTE_POS * 3;

    // Surplus units (the launch grid may overshoot) stop here. Checking only
    // `start` is enough because vectors are whole triples.
    if !output.is_in_bounds(start) {
        terminate!();
    }

    let lx = lhs[start];
    let ly = lhs[start + 1];
    let lz = lhs[start + 2];

    let rx = rhs[start];
    let ry = rhs[start + 1];
    let rz = rhs[start + 2];

    output[start] = ly * rz - lz * ry;
    output[start + 1] = lz * rx - lx * rz;
    output[start + 2] = lx * ry - ly * rx;
}
/// Computes the cross product of `lhs` and `rhs` along `dim` on the CubeCL backend.
///
/// The output has shape `broadcast_shape(&[&lhs, &rhs])` and uses `lhs`'s dtype.
/// Inputs are read through linear views laid out against the output, and the
/// kernel is launched with a line size of 1.
///
/// # Panics
///
/// - If `dim` indexes past the shape of either input.
/// - If `dim` does not have size 3 in both `lhs` and `rhs`.
/// - Via `unimplemented!` if `dim` is not the last dimension of `lhs`.
/// - If the kernel launch returns an error.
pub(crate) fn cross<R: CubeRuntime>(
    lhs: CubeTensor<R>,
    rhs: CubeTensor<R>,
    dim: usize,
) -> CubeTensor<R> {
    let rank = lhs.meta.num_dims();
    let lhs_len = lhs.meta.shape()[dim];
    let rhs_len = rhs.meta.shape()[dim];

    assert!(
        lhs_len == 3 && rhs_len == 3,
        "Cross product requires dimension {} to have size 3, but got {} and {}",
        dim,
        lhs_len,
        rhs_len
    );

    // The kernel assumes each vector's 3 components are contiguous, which
    // only holds when crossing along the innermost dimension.
    if dim != rank - 1 {
        unimplemented!(
            "Cross product on non-last dimension not yet implemented for CubeCL backend"
        );
    }

    let out_shape = broadcast_shape(&[&lhs, &rhs]);
    // One kernel unit per 3-component output vector.
    let vector_count = out_shape.num_elements() / 3;
    let vectorization = 1;

    let output = empty_device_dtype(
        lhs.client.clone(),
        lhs.device.clone(),
        out_shape,
        lhs.dtype,
    );

    let cube_dim = CubeDim::new(&lhs.client, vector_count);
    let cube_count = calculate_cube_count_elemwise(&lhs.client, vector_count, cube_dim);

    // SAFETY: the kernel discards units whose first output position is out of
    // bounds. Every surviving unit touches exactly one in-bounds triple,
    // because the last dimension was checked above to have size 3. The input
    // views are built against `output`'s layout (presumably handling
    // broadcasting; confirm in `linear_view_ref`).
    unsafe {
        cross_kernel::launch_unchecked(
            &lhs.client,
            cube_count,
            cube_dim,
            address_type!(lhs, rhs, output),
            linear_view_ref(&lhs, &output, vectorization),
            linear_view_ref(&rhs, &output, vectorization),
            linear_view(&output, vectorization),
            lhs.dtype.into(),
        )
        .expect("Kernel to never fail");
    }

    output
}