use alloc::{vec, vec::Vec};
use crate as cubecl;
use cubecl::prelude::*;
/// Copies the first four shape entries of `lhs`, `rhs` and `out` into `out`.
///
/// Output layout (12 `u32` slots):
/// `[lhs.shape(0..4), rhs.shape(0..4), out.shape(0..4)]`.
///
/// Dimensions 0..=3 are read from every tensor, so all three are expected to have
/// rank >= 4. Only slots 0..=11 are written, so the buffer behind `out` needs 12 `u32`s
/// even when its declared shape describes more elements.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn kernel_shape_dim_4(lhs: &Tensor<f32>, rhs: &Tensor<f32>, out: &mut Tensor<u32>) {
    // Units at or beyond the logical length of `out` stop here.
    if ABSOLUTE_POS >= out.len() {
        terminate!();
    }
    // NOTE(review): every unit that passes the guard writes the same slots; the host
    // test launches a single unit.
    out[0] = lhs.shape(0) as u32;
    out[1] = lhs.shape(1) as u32;
    out[2] = lhs.shape(2) as u32;
    out[3] = lhs.shape(3) as u32;
    out[4] = rhs.shape(0) as u32;
    out[5] = rhs.shape(1) as u32;
    out[6] = rhs.shape(2) as u32;
    out[7] = rhs.shape(3) as u32;
    out[8] = out.shape(0) as u32;
    out[9] = out.shape(1) as u32;
    out[10] = out.shape(2) as u32;
    out[11] = out.shape(3) as u32;
}
/// Copies the shapes and ranks of three tensors of different rank into `out`.
///
/// Output layout (12 `u32` slots):
/// `[lhs.shape(0..4), rhs.shape(0..3), out.shape(0..2), lhs.rank(), rhs.rank(), out.rank()]`.
///
/// The dimensions read imply `lhs` has rank >= 4, `rhs` rank >= 3 and `out` rank >= 2.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn kernel_shape_different_ranks(lhs: &Tensor<f32>, rhs: &Tensor<f32>, out: &mut Tensor<u32>) {
    // Units at or beyond the logical length of `out` stop here.
    if ABSOLUTE_POS >= out.len() {
        terminate!();
    }
    // Shapes, one slot per dimension actually read.
    out[0] = lhs.shape(0) as u32;
    out[1] = lhs.shape(1) as u32;
    out[2] = lhs.shape(2) as u32;
    out[3] = lhs.shape(3) as u32;
    out[4] = rhs.shape(0) as u32;
    out[5] = rhs.shape(1) as u32;
    out[6] = rhs.shape(2) as u32;
    out[7] = out.shape(0) as u32;
    out[8] = out.shape(1) as u32;
    // Ranks, so the host can check per-tensor rank metadata too.
    out[9] = lhs.rank() as u32;
    out[10] = rhs.rank() as u32;
    out[11] = out.rank() as u32;
}
/// Copies the strides of three tensors of different rank into `out`.
///
/// Output layout (9 `u32` slots):
/// `[lhs.stride(0..4), rhs.stride(0..3), out.stride(0..2)]`.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn kernel_stride_different_ranks(lhs: &Tensor<f32>, rhs: &Tensor<f32>, out: &mut Tensor<u32>) {
    // Units at or beyond the logical length of `out` stop here.
    if ABSOLUTE_POS >= out.len() {
        terminate!();
    }
    out[0] = lhs.stride(0) as u32;
    out[1] = lhs.stride(1) as u32;
    out[2] = lhs.stride(2) as u32;
    out[3] = lhs.stride(3) as u32;
    out[4] = rhs.stride(0) as u32;
    out[5] = rhs.stride(1) as u32;
    out[6] = rhs.stride(2) as u32;
    out[7] = out.stride(0) as u32;
    out[8] = out.stride(1) as u32;
}
/// Copies the logical lengths of `lhs`, `rhs` and `out` into `out[0..3]`.
///
/// The host-side test expects each length to equal the product of that tensor's shape.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn kernel_len_different_ranks(lhs: &Tensor<f32>, rhs: &Tensor<f32>, out: &mut Tensor<u32>) {
    // Units at or beyond the logical length of `out` stop here.
    if ABSOLUTE_POS >= out.len() {
        terminate!();
    }
    out[0] = lhs.len() as u32;
    out[1] = rhs.len() as u32;
    out[2] = out.len() as u32;
}
/// Stores `out.buffer_len()` (cast to `u32`) in the first element of `out`.
///
/// `N` is the vector width of the tensor elements. The host tests expect `buffer_len`
/// to count whole vectors in the underlying allocation, not logical shape elements.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn kernel_buffer_len<N: Size>(out: &mut Tensor<Vector<u32, N>>) {
    // Units at or beyond the logical length of `out` stop here.
    if ABSOLUTE_POS >= out.len() {
        terminate!();
    }
    // NOTE(review): `Vector::new` presumably splats the scalar across all N lanes;
    // the host tests only inspect the first `u32`.
    out[0] = Vector::new(out.buffer_len() as u32);
}
/// Launches `kernel_shape_dim_4` on three rank-4 tensors and asserts that the shapes
/// observed inside the kernel match the shapes declared on the host.
///
/// Does nothing when the runtime does not support `addr_type`.
pub fn test_shape_dim_4<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    // 3 tensors x 4 dimensions = 12 `u32` outputs.
    let size = 12 * core::mem::size_of::<u32>();
    let lhs = client.empty(size);
    let rhs = client.empty(size);
    let out = client.empty(size);

    // SAFETY: the kernel only writes `out[0..12]`, which the 12-`u32` allocation covers.
    unsafe {
        let lhs_arg = TensorArg::from_raw_parts(lhs, [1, 1, 1, 1].into(), [2, 3, 4, 5].into());
        let rhs_arg = TensorArg::from_raw_parts(rhs, [1, 1, 1, 1].into(), [9, 8, 7, 6].into());
        let out_arg = TensorArg::from_raw_parts(
            out.clone(),
            [1, 1, 1, 1].into(),
            [10, 11, 12, 13].into(),
        );
        kernel_shape_dim_4::launch_unchecked(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            addr_type,
            lhs_arg,
            rhs_arg,
            out_arg,
        )
    };

    let raw = client.read_one_unchecked(out);
    let values = u32::from_bytes(&raw);
    let expected: Vec<u32> = vec![
        2, 3, 4, 5, // lhs shape
        9, 8, 7, 6, // rhs shape
        10, 11, 12, 13, // out shape
    ];
    assert_eq!(values, &expected);
}
/// Launches `kernel_shape_different_ranks` with tensors of rank 4, 3 and 2 and asserts
/// that both their shapes and their ranks read back as declared on the host.
///
/// Does nothing when the runtime does not support `addr_type`.
pub fn test_shape_different_ranks<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    // 4 + 3 + 2 shape entries plus 3 ranks = 12 `u32` outputs.
    let size = 12 * core::mem::size_of::<u32>();
    let lhs = client.empty(size);
    let rhs = client.empty(size);
    let out = client.empty(size);

    // SAFETY: the kernel only writes `out[0..12]`, which the 12-`u32` allocation covers.
    unsafe {
        let lhs_arg = TensorArg::from_raw_parts(lhs, [1, 1, 1, 1].into(), [2, 3, 4, 5].into());
        let rhs_arg = TensorArg::from_raw_parts(rhs, [1, 1, 1].into(), [9, 8, 7].into());
        let out_arg = TensorArg::from_raw_parts(out.clone(), [1, 1].into(), [10, 11].into());
        kernel_shape_different_ranks::launch_unchecked(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            addr_type,
            lhs_arg,
            rhs_arg,
            out_arg,
        )
    };

    let raw = client.read_one_unchecked(out);
    let values = u32::from_bytes(&raw);
    let expected: Vec<u32> = vec![
        2, 3, 4, 5, // lhs shape
        9, 8, 7, // rhs shape
        10, 11, // out shape
        4, 3, 2, // lhs, rhs and out ranks
    ];
    assert_eq!(values, &expected);
}
/// Launches `kernel_stride_different_ranks` with tensors of rank 4, 3 and 2 and asserts
/// that the strides observed in the kernel match those passed from the host.
///
/// Does nothing when the runtime does not support `addr_type`.
pub fn test_stride_different_ranks<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    // 4 + 3 + 2 stride entries = 9 `u32` outputs.
    let size = 9 * core::mem::size_of::<u32>();
    let lhs = client.empty(size);
    let rhs = client.empty(size);
    let out = client.empty(size);

    // SAFETY: the kernel only writes `out[0..9]`, which the 9-`u32` allocation covers.
    unsafe {
        // Shapes are all ones here; only the strides matter for this check.
        let lhs_arg = TensorArg::from_raw_parts(lhs, [1, 2, 3, 4].into(), [1, 1, 1, 1].into());
        let rhs_arg = TensorArg::from_raw_parts(rhs, [4, 5, 6].into(), [1, 1, 1].into());
        let out_arg = TensorArg::from_raw_parts(out.clone(), [3, 2].into(), [1, 1].into());
        kernel_stride_different_ranks::launch_unchecked(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            addr_type,
            lhs_arg,
            rhs_arg,
            out_arg,
        )
    };

    let raw = client.read_one_unchecked(out);
    let values = u32::from_bytes(&raw);
    let expected: Vec<u32> = vec![
        1, 2, 3, 4, // lhs strides
        4, 5, 6, // rhs strides
        3, 2, // out strides
    ];
    assert_eq!(values, &expected);
}
/// Launches `kernel_len_different_ranks` and asserts that each tensor's length equals
/// the product of its declared shape.
///
/// Does nothing when the runtime does not support `addr_type`.
pub fn test_len_different_ranks<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    // One length per tensor = 3 `u32` outputs.
    let size = 3 * core::mem::size_of::<u32>();
    let lhs = client.empty(size);
    let rhs = client.empty(size);
    let out = client.empty(size);

    // SAFETY: the kernel only writes `out[0..3]`, which the 3-`u32` allocation covers.
    unsafe {
        let lhs_arg = TensorArg::from_raw_parts(lhs, [1, 1, 1, 1].into(), [2, 3, 4, 5].into());
        let rhs_arg = TensorArg::from_raw_parts(rhs, [1, 1, 1].into(), [9, 8, 7].into());
        let out_arg = TensorArg::from_raw_parts(out.clone(), [1, 1].into(), [10, 11].into());
        kernel_len_different_ranks::launch_unchecked(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            addr_type,
            lhs_arg,
            rhs_arg,
            out_arg,
        )
    };

    let raw = client.read_one_unchecked(out);
    let values = u32::from_bytes(&raw);
    let expected: Vec<u32> = vec![
        2 * 3 * 4 * 5, // lhs
        9 * 8 * 7,     // rhs
        10 * 11,       // out
    ];
    assert_eq!(values, &expected);
}
/// Checks that `buffer_len` reports the size of the whole allocation (64 `u32`s) for a
/// strided `[2, 2, 2, 2]` view, rather than its 16 logical elements.
///
/// Does nothing when the runtime does not support `addr_type`.
pub fn test_buffer_len_discontiguous<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    let buffer = client.empty(64 * core::mem::size_of::<u32>());

    // SAFETY: the kernel only writes the first element, well inside the allocation.
    unsafe {
        // Strides [32, 16, 4, 1] are wider than a contiguous [2, 2, 2, 2] layout needs.
        let arg = TensorArg::from_raw_parts(
            buffer.clone(),
            [32, 16, 4, 1].into(),
            [2, 2, 2, 2].into(),
        );
        kernel_buffer_len::launch_unchecked(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            addr_type,
            1, // vector width N
            arg,
        )
    };

    let raw = client.read_one_unchecked(buffer);
    let values = u32::from_bytes(&raw);
    assert_eq!(values[0], 64);
}
/// Checks that `buffer_len` counts vectors rather than scalars: 32 `u32`s viewed with a
/// vector width of 4 should report 8.
///
/// Does nothing when the runtime does not support `addr_type`.
pub fn test_buffer_len_vectorized<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    let buffer = client.empty(32 * core::mem::size_of::<u32>());

    // SAFETY: the kernel only writes the first vector (4 `u32`s), inside the allocation.
    unsafe {
        let arg = TensorArg::from_raw_parts(
            buffer.clone(),
            [16, 8, 4, 1].into(),
            [2, 2, 2, 4].into(),
        );
        kernel_buffer_len::launch_unchecked(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            addr_type,
            4, // vector width N
            arg,
        )
    };

    let raw = client.read_one_unchecked(buffer);
    let values = u32::from_bytes(&raw);
    // 32 scalars / 4 lanes per vector.
    assert_eq!(values[0], 8);
}
/// Checks that `buffer_len` respects handle offsets: 64 `u32`s are trimmed from each end
/// of a 256-`u32` allocation, and with a vector width of 2 the expected length is
/// (256 - 64 - 64) / 2 = 64.
///
/// Does nothing when the runtime does not support `addr_type`.
pub fn test_buffer_len_offset<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    // Byte count trimmed from each end of the allocation.
    let trim = 64 * core::mem::size_of::<u32>() as u64;
    let buffer = client
        .empty(256 * core::mem::size_of::<u32>())
        .offset_start(trim)
        .offset_end(trim);

    // SAFETY: the kernel only writes the first vector of the trimmed view.
    unsafe {
        let arg = TensorArg::from_raw_parts(
            buffer.clone(),
            [32, 16, 4, 1].into(),
            [4, 4, 4, 8].into(),
        );
        kernel_buffer_len::launch_unchecked(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            addr_type,
            2, // vector width N
            arg,
        )
    };

    let raw = client.read_one_unchecked(buffer);
    let values = u32::from_bytes(&raw);
    assert_eq!(values[0], 64);
}
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_metadata {
    () => {
        mod metadata {
            use super::*;

            // Every test runs its helpers under both address widths, in U32-then-U64 order.
            // The helpers return early when the runtime does not support an address type.
            //
            // All paths go through `$crate`, like the `test_log::test` attribute, so the
            // expansion does not require the invoking crate to name `cubecl_core` directly.

            #[$crate::runtime_tests::test_log::test]
            fn test_shape() {
                let client = TestRuntime::client(&Default::default());
                for addr_type in [AddressType::U32, AddressType::U64] {
                    $crate::runtime_tests::metadata::test_shape_dim_4::<TestRuntime>(
                        client.clone(),
                        addr_type,
                    );
                }
                for addr_type in [AddressType::U32, AddressType::U64] {
                    $crate::runtime_tests::metadata::test_shape_different_ranks::<TestRuntime>(
                        client.clone(),
                        addr_type,
                    );
                }
            }

            #[$crate::runtime_tests::test_log::test]
            fn test_stride() {
                let client = TestRuntime::client(&Default::default());
                for addr_type in [AddressType::U32, AddressType::U64] {
                    $crate::runtime_tests::metadata::test_stride_different_ranks::<TestRuntime>(
                        client.clone(),
                        addr_type,
                    );
                }
            }

            #[$crate::runtime_tests::test_log::test]
            fn test_len() {
                let client = TestRuntime::client(&Default::default());
                for addr_type in [AddressType::U32, AddressType::U64] {
                    $crate::runtime_tests::metadata::test_len_different_ranks::<TestRuntime>(
                        client.clone(),
                        addr_type,
                    );
                }
            }

            #[$crate::runtime_tests::test_log::test]
            fn test_buffer_len_discontiguous() {
                let client = TestRuntime::client(&Default::default());
                for addr_type in [AddressType::U32, AddressType::U64] {
                    $crate::runtime_tests::metadata::test_buffer_len_discontiguous::<TestRuntime>(
                        client.clone(),
                        addr_type,
                    );
                }
            }

            #[$crate::runtime_tests::test_log::test]
            fn test_buffer_len_vectorized() {
                let client = TestRuntime::client(&Default::default());
                for addr_type in [AddressType::U32, AddressType::U64] {
                    $crate::runtime_tests::metadata::test_buffer_len_vectorized::<TestRuntime>(
                        client.clone(),
                        addr_type,
                    );
                }
            }

            #[$crate::runtime_tests::test_log::test]
            fn test_buffer_len_offset() {
                let client = TestRuntime::client(&Default::default());
                for addr_type in [AddressType::U32, AddressType::U64] {
                    $crate::runtime_tests::metadata::test_buffer_len_offset::<TestRuntime>(
                        client.clone(),
                        addr_type,
                    );
                }
            }
        }
    };
}