use crate::{self as cubecl, prelude::barrier::Barrier};
use alloc::{fmt::Debug, vec, vec::Vec};
use cubecl::prelude::*;
use cubecl_ir::features::Tma;
use cubecl_runtime::{
server::{ComputeServer, CopyDescriptor, MemoryLayout},
storage::ComputeStorage,
};
use cubecl_zspace::{Shape, shape, strides};
use std::println;
/// Test kernel: loads one tile of the `input` tensor map into shared memory
/// through a barrier-tracked TMA load, then copies it into `output`.
///
/// The staging buffer holds `32 * 16` elements and every unit copies the
/// element at `UNIT_POS_Y * 32 + UNIT_POS_X`, so the kernel expects a 32x16
/// cube (which is how `test_tensormap_load` launches it).
#[cube(launch)]
fn tensormap_load<F: Float, N: Size>(
input: &TensorMap<F, Tiled>,
output: &mut Array<Vector<F, N>>,
) {
// The second argument is true for unit 0 only — presumably it elects the unit
// that initializes the barrier; confirm against `Barrier::shared`.
let barrier = Barrier::shared(CUBE_DIM, UNIT_POS == 0);
sync_async_proxy_shared();
// Shared-memory staging buffer for one 16x32 tile. The second argument (128)
// is presumably the required alignment — confirm against `new_aligned`.
let mut stage = SharedMemory::new_aligned(32usize * 16, 128usize);
let type_size = F::type_size();
// Only unit 0 announces a non-zero transaction size (element count times
// `type_size`, presumably bytes); every other unit passes 0.
let expected = select(UNIT_POS == 0, comptime![32 * 16 * type_size] as u32, 0);
if UNIT_POS == 0 {
// A single unit issues the load at coordinates (0, 8). The host test expects
// rows 0..16 and columns 8..40 of the source, i.e. the order is (row, column).
barrier.tma_load_2d(input, &mut stage.to_slice_mut(), 0, 8);
}
// Every unit arrives on the barrier and waits before reading the stage.
let token = barrier.arrive_and_expect_tx(1, expected);
barrier.wait(token);
// One element per unit, row-major within the 32-wide tile.
let out_pos = UNIT_POS_Y * 32 + UNIT_POS_X;
output[out_pos as usize] = stage[out_pos as usize];
}
/// Test kernel: copies `input` into shared memory and writes it into the
/// `output` tensor map with a TMA store.
///
/// Every unit stages the element at `UNIT_POS_Y * 32 + UNIT_POS_X`, so the
/// kernel expects a 32x16 cube (which is how `test_tensormap_store` launches
/// it) and an `input` of at least `32 * 16` elements.
#[cube(launch)]
fn tensormap_store<F: Float, N: Size>(
input: &Array<Vector<F, N>>,
output: &mut TensorMap<F, Tiled>,
) {
// Shared-memory staging buffer for one 16x32 tile. The second argument (128)
// is presumably the required alignment — confirm against `new_aligned`.
let mut shared = SharedMemory::new_aligned(32usize * 16, 128usize);
let in_pos = UNIT_POS_Y * 32 + UNIT_POS_X;
shared[in_pos as usize] = input[in_pos as usize];
// Both synchronizations run before the store is issued, so unit 0 only starts
// the TMA store once every unit has finished writing its element.
sync_async_proxy_shared();
sync_cube();
if UNIT_POS == 0 {
// A single unit issues the store at coordinates (16, 8). The host test expects
// the tile at rows 16..32 and columns 8..40, i.e. the order is (row, column).
tma_store_2d(&shared.to_slice(), output, 16, 8);
// NOTE(review): presumably commits the pending store group and waits until the
// shared-memory source has been read — confirm against the TMA helper docs.
tma_group_commit();
tma_group_wait_read(0u32);
}
}
/// Test kernel: performs one im2col TMA load per kernel offset into shared
/// memory, then copies the whole stage into `output`.
///
/// Compile-time parameters:
/// * `tile_m` - multiplied by `channels` to get the element count of one load.
/// * `kernel_h`, `kernel_w` - number of kernel offsets along each axis; one
///   load is issued per `(kernel_y, kernel_x)` pair.
/// * `channels` - see `tile_m`.
/// * `pad_h`, `pad_w` - negated and passed as two of the load coordinates.
///
/// The final copy indexes both buffers with `ABSOLUTE_POS`, so the launch has
/// to provide one unit per staged element
/// (`kernel_h * kernel_w * tile_m * channels`).
#[cube(launch)]
fn tensormap_im2col_load<F: Float, N: Size>(
input: &TensorMap<F, Im2col>,
output: &mut Tensor<Vector<F, N>>,
#[comptime] tile_m: usize,
#[comptime] kernel_h: u16,
#[comptime] kernel_w: u16,
#[comptime] channels: usize,
#[comptime] pad_h: i32,
#[comptime] pad_w: i32,
) {
// Number of kernel offsets, i.e. number of slices in the stage.
let tile_k = comptime!(kernel_h as usize * kernel_w as usize);
// Elements written by a single load.
let tile_width = tile_m * channels;
// The second argument is true for unit 0 only — presumably it elects the unit
// that initializes the barrier; confirm against `Barrier::shared`.
let barrier = Barrier::shared(CUBE_DIM, UNIT_POS == 0);
sync_async_proxy_shared();
// Shared-memory stage: `tile_k` consecutive slices of `tile_width` elements.
let mut stage = SharedMemory::new_aligned(tile_k * tile_width, 128usize);
let type_size = F::type_size();
// Only unit 0 announces a non-zero transaction size covering all loads
// (element count times `type_size`, presumably bytes).
let expected = select(
UNIT_POS == 0,
comptime![tile_width * tile_k * type_size] as u32,
0,
);
if UNIT_POS == 0 {
#[unroll]
for kernel_y in 0..kernel_h {
#[unroll]
for kernel_x in 0..kernel_w {
// Slices are ordered row-major over the kernel offsets.
let kernel_idx = kernel_y * kernel_w + kernel_x;
let slice_start = kernel_idx as usize * tile_width;
let slice_end = slice_start + tile_width;
let mut stage_slice = stage.slice_mut(slice_start, slice_end);
// NOTE(review): the four coordinates presumably follow the input's
// (n, h, w, c) layout, with the load starting in the padded region;
// `kernel_y`/`kernel_x` are the per-load offsets — confirm against
// `tma_load_im2col_4d`.
barrier.tma_load_im2col_4d(
input,
&mut stage_slice,
0,
-pad_h,
-pad_w,
0,
kernel_y,
kernel_x,
);
}
}
}
// Every unit arrives on the barrier and waits before reading the stage.
let token = barrier.arrive_and_expect_tx(1, expected);
barrier.wait(token);
output[ABSOLUTE_POS] = stage[ABSOLUTE_POS];
}
/// Test kernel: writes the first dimension of each argument into `output_2`,
/// mixing plain tensors and tensor maps in the same argument list.
///
/// Resulting layout of `output_2`:
/// * `[0]` - `input_1.shape(0)`
/// * `[1]` - `input_2.shape(0)`
/// * `[2]` - `output.shape(0)`
/// * `[3]` - `output_2.shape(0)`
///
/// No element of `input_1`, `input_2` or `output` is read or written; only
/// their shape metadata is queried.
#[cube(launch)]
fn tensormap_metadata<F: Float, N: Size>(
input_1: &Tensor<Vector<F, N>>,
output: &mut TensorMap<F, Tiled>,
input_2: &TensorMap<F, Tiled>,
output_2: &mut Tensor<u32>,
) {
output_2[0] = input_1.shape(0) as u32;
output_2[1] = input_2.shape(0) as u32;
output_2[2] = output.shape(0) as u32;
output_2[3] = output_2.shape(0) as u32;
}
/// Launches `tensormap_load` on a 64x64 tensor whose elements equal their own
/// linear index and checks the 16x32 tile that comes back.
///
/// The kernel loads at coordinates `(0, 8)`, so the expected result is rows
/// `0..16`, columns `8..40` of the source. The test prints a message and
/// returns early when the client does not report `Tma::Base`.
pub fn test_tensormap_load<R: Runtime, F: Float + CubeElement>(client: ComputeClient<R>)
where
    <<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource: Debug,
{
    let tma_available = client.features().tma.contains(Tma::Base);
    if !tma_available {
        println!("Skipped test_tensormap_load due to unavailability");
        return;
    }

    // Source data: a 64x64 ramp, element value == linear index.
    let mut source: Vec<F> = Vec::with_capacity(64 * 64);
    for linear in 0..64 * 64 {
        source.push(F::from_int(linear));
    }
    let source_shape = shape![64, 64];
    let MemoryLayout { memory, strides } =
        client.create_tensor_from_slice(F::as_bytes(&source), source_shape.clone(), size_of::<F>());
    // SAFETY (assumed contract, confirm against `TensorArg::from_raw_parts`):
    // handle, strides and shape all come from the allocation made just above.
    let source_arg = unsafe { TensorArg::from_raw_parts(memory.clone(), strides, source_shape) };

    // Destination buffer for exactly one 16x32 tile.
    let tile_buffer = client.empty(16 * 32 * size_of::<F>());

    tensormap_load::launch::<F, R>(
        &client,
        CubeCount::Static(1, 1, 1),
        CubeDim::new_2d(32, 16),
        1,
        TensorMapArg::new(
            TiledArgs {
                tile_size: shape![16, 32],
            },
            source_arg,
            F::as_type_native_unchecked(),
        ),
        // SAFETY (assumed contract): the buffer was sized for 32 * 16 elements.
        unsafe { ArrayArg::from_raw_parts(tile_buffer.clone(), 32 * 16) },
    );

    let raw = client.read_one_unchecked(tile_buffer);
    let actual = F::from_bytes(&raw);

    // Row `row`, columns 8..40 of the 64-wide source.
    let mut expected: Vec<F> = Vec::with_capacity(16 * 32);
    for row in 0..16 {
        for col in 0..32 {
            expected.push(F::from_int(row * 64 + col + 8));
        }
    }

    assert_eq!(actual, &expected);
}
/// Launches `tensormap_store` with a 16x32 ramp tile and checks that it lands
/// in the right place of a zero-initialized 64x64 output tensor.
///
/// The kernel stores at coordinates `(16, 8)`, so the tile is expected at rows
/// `16..32`, columns `8..40`; every other element must stay zero. The test
/// prints a message and returns early when the client does not report
/// `Tma::Base`.
pub fn test_tensormap_store<R: Runtime, F: Float + CubeElement>(client: ComputeClient<R>)
where
    <<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource: Debug,
{
    if !client.features().tma.contains(Tma::Base) {
        // Name this test (not `test_tensormap_load`) so skip logs are attributable.
        println!("Skipped test_tensormap_store due to unavailability");
        return;
    }

    // Tile to store: element value == linear index within the 16x32 tile.
    let values = (0..32 * 16).map(|it| F::from_int(it)).collect::<Vec<_>>();
    let handle = client.create_from_slice(F::as_bytes(&values));

    // Zero-filled 64x64 destination tensor.
    let out_shape = &[64, 64];
    let out = client.create_tensor_from_slice(
        &vec![0u8; 64 * 64 * size_of::<F>()],
        out_shape.into(),
        size_of::<F>(),
    );

    tensormap_store::launch::<F, R>(
        &client,
        CubeCount::Static(1, 1, 1),
        CubeDim::new_2d(32, 16),
        1,
        // SAFETY (assumed contract, confirm against `ArrayArg::from_raw_parts`):
        // the handle was created above from exactly 32 * 16 elements.
        unsafe { ArrayArg::from_raw_parts(handle.clone(), 32 * 16) },
        TensorMapArg::new(
            TiledArgs {
                tile_size: shape![16, 32],
            },
            // SAFETY (assumed contract): handle and strides come from the 64x64
            // allocation made above.
            unsafe {
                TensorArg::from_raw_parts(out.memory.clone(), out.strides.clone(), [64, 64].into())
            },
            F::as_type_native_unchecked(),
        ),
    );

    // Read back through the tensor's own strides.
    let actual = client.read_one_unchecked_tensor(CopyDescriptor::new(
        out.memory.clone().binding(),
        out_shape.into(),
        out.strides.clone(),
        size_of::<F>(),
    ));
    let actual = F::from_bytes(&actual);

    // Zeros everywhere except the tile at row offset 16, column offset 8.
    let mut expected: Vec<F> = vec![F::from_int(0); 64 * 64];
    for y in 0..16 {
        for x in 0..32 {
            let index = (y + 16) * 64 + (x + 8);
            expected[index] = F::from_int((y * 32 + x) as i64);
        }
    }

    assert_eq!(actual, &expected);
}
/// Launches `tensormap_im2col_load` on a `1x3x3x8` input with a 2x2 kernel and
/// padding 1, and checks the resulting im2col buffer.
///
/// Input values start at 1, so any zero in the output can only come from a
/// position outside the input. The test prints a message and returns early
/// when the client does not report `Tma::Base`.
pub fn test_tensormap_load_im2col<R: Runtime, F: Float + CubeElement>(client: ComputeClient<R>)
where
    <<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource: Debug,
{
    if !client.features().tma.contains(Tma::Base) {
        // Name this test (not `test_tensormap_load`) so skip logs are attributable.
        println!("Skipped test_tensormap_load_im2col due to unavailability");
        return;
    }

    // Problem dimensions.
    let n = 1;
    let h = 3;
    let w = 3;
    let c = 8;
    let kernel_h = 2;
    let kernel_w = 2;
    let pad_h = 1;
    let pad_w = 1;
    // Upper corner of the pixel box handed to `Im2colArgs`.
    let corner_h = pad_h - (kernel_h as i32 - 1);
    let corner_w = pad_w - (kernel_w as i32 - 1);
    let out_h = 4;
    let out_w = 4;

    let tile_m = n * out_h * out_w;
    let tile_k = kernel_h * kernel_w * c;
    let out_size = tile_m * tile_k;

    // Input values 1..=h*w*c: the input contains no zeros.
    let values = (1..h * w * c + 1)
        .map(|it| F::from_int(it as i64))
        .collect::<Vec<_>>();
    let shape: Shape = [n, h, w, c].into();
    let MemoryLayout {
        memory: handle,
        strides,
    } = client.create_tensor_from_slice(F::as_bytes(&values), shape.clone(), size_of::<F>());
    // SAFETY (assumed contract, confirm against `TensorArg::from_raw_parts`):
    // handle, strides and shape all come from the allocation made just above.
    let input = unsafe { TensorArg::from_raw_parts(handle, strides, shape) };

    let out_shape = [tile_k, tile_m];
    let out_strides = [tile_m, 1];
    let out = client.empty(out_size * size_of::<F>());

    tensormap_im2col_load::launch::<F, R>(
        &client,
        CubeCount::Static(1, 1, 1),
        // One unit per output element: (tile_m * c) x (kernel_h * kernel_w) == out_size.
        CubeDim::new_2d(tile_m as u32 * c as u32, kernel_h as u32 * kernel_w as u32),
        1,
        TensorMapArg::new(
            Im2colArgs {
                pixel_box_lower_corner: vec![-pad_h, -pad_w],
                pixel_box_upper_corner: vec![corner_h, corner_w],
                channels_per_pixel: c as u32,
                pixels_per_column: tile_m as u32,
            },
            input,
            F::as_type_native_unchecked(),
        ),
        // SAFETY (assumed contract): the buffer holds `out_size` elements, which
        // equals the product of `out_shape`.
        unsafe { TensorArg::from_raw_parts(out.clone(), out_strides.into(), out_shape.into()) },
        tile_m,
        kernel_h as u16,
        kernel_w as u16,
        c,
        pad_h,
        pad_w,
    );

    let actual = client.read_one_unchecked(out);
    let actual = F::from_bytes(&actual);

    // One block of 16 output pixels per kernel offset, in the order
    // (0, 0), (0, 1), (1, 0), (1, 1). Each entry is the 1-based index of the
    // input pixel expected at that output position; 0 marks a position that is
    // expected to read as zeros.
    let mut expected = vec![0, 0, 0, 0, 0, 1, 2, 3, 0, 4, 5, 6, 0, 7, 8, 9];
    expected.extend([0, 0, 0, 0, 1, 2, 3, 0, 4, 5, 6, 0, 7, 8, 9, 0]);
    expected.extend([0, 1, 2, 3, 0, 4, 5, 6, 0, 7, 8, 9, 0, 0, 0, 0]);
    expected.extend([1, 2, 3, 0, 4, 5, 6, 0, 7, 8, 9, 0, 0, 0, 0, 0]);

    // Expand every pixel index into its `c` channel values.
    let expected_actual: Vec<F> = expected
        .iter()
        .flat_map(|v| {
            if *v == 0 {
                vec![0; c]
            } else {
                // Pixel `v` owns the input values (v - 1) * c + 1 ..= v * c.
                let ch_start = (*v - 1) * c + 1;
                (ch_start..ch_start + c).collect()
            }
        })
        .map(|v| F::from_int(v as i64))
        .collect();

    assert_eq!(actual, &expected_actual);
}
/// Launches `tensormap_metadata` with two plain tensors and two tensor maps and
/// checks that each argument reports its own first dimension.
///
/// The four arguments get distinct leading dimensions (2, 4, 6, 8) so that any
/// mix-up of metadata between arguments shows up in the result. The kernel
/// only queries shapes and writes four `u32` values, so the input buffers are
/// never read. The test prints a message and returns early when the client
/// does not report `Tma::Base`.
pub fn test_tensormap_metadata<R: Runtime, F: Float + CubeElement>(client: ComputeClient<R>)
where
    <<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource: Debug,
{
    if !client.features().tma.contains(Tma::Base) {
        // Name this test (not `test_tensormap_load`) so skip logs are attributable.
        println!("Skipped test_tensormap_metadata due to unavailability");
        return;
    }

    let in_handle_1 = client.empty(4);
    let in_handle_2 = client.empty(64);
    let out_handle_1 = client.empty(64);
    // Receives the four `u32` results.
    let out_handle_2 = client.empty(size_of::<u32>() * 4);

    // All arguments share the same strides; only the shapes differ.
    let strides = strides![16, 1];
    // SAFETY (assumed contract, confirm against `TensorArg::from_raw_parts`):
    // the kernel reads only shape metadata from these three arguments and never
    // dereferences their buffers.
    let input_1 = unsafe { TensorArg::from_raw_parts(in_handle_1, strides.clone(), [2, 3].into()) };
    let input_2 = unsafe { TensorArg::from_raw_parts(in_handle_2, strides.clone(), [4, 5].into()) };
    let output_1 =
        unsafe { TensorArg::from_raw_parts(out_handle_1.clone(), strides.clone(), [6, 7].into()) };
    // SAFETY (assumed contract): the kernel writes indices 0..4 only, which fit
    // in the four-`u32` buffer allocated above.
    let output_2 =
        unsafe { TensorArg::from_raw_parts(out_handle_2.clone(), strides, [8, 9].into()) };

    tensormap_metadata::launch::<F, R>(
        &client,
        CubeCount::Static(1, 1, 1),
        CubeDim::new_2d(32, 16),
        1,
        input_1,
        TensorMapArg::new(
            TiledArgs {
                tile_size: shape![16, 16],
            },
            output_1,
            F::as_type_native_unchecked(),
        ),
        TensorMapArg::new(
            TiledArgs {
                tile_size: shape![16, 32],
            },
            input_2,
            F::as_type_native_unchecked(),
        ),
        output_2,
    );

    let actual = client.read_one_unchecked(out_handle_2);
    let actual = u32::from_bytes(&actual);

    // Order written by the kernel: input_1, input_2, output (map), output_2.
    assert_eq!(actual, &[2, 4, 6, 8]);
}
/// Generates the tensor-map (TMA) runtime tests for a backend.
///
/// Expands to one test function per `test_tensormap_*` function of this
/// module. The expansion begins with `use super::*`, so the parent module of
/// the invocation site must provide `TestRuntime` and `FloatType`.
///
/// All paths go through `$crate`, so the expansion does not depend on the name
/// under which the invoking crate imports this crate.
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_tensormap {
    () => {
        use super::*;

        #[$crate::runtime_tests::test_log::test]
        fn test_tensormap_load() {
            let client = TestRuntime::client(&Default::default());
            $crate::runtime_tests::tensormap::test_tensormap_load::<TestRuntime, FloatType>(
                client,
            );
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_tensormap_load_im2col() {
            let client = TestRuntime::client(&Default::default());
            $crate::runtime_tests::tensormap::test_tensormap_load_im2col::<TestRuntime, FloatType>(
                client,
            );
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_tensormap_store() {
            let client = TestRuntime::client(&Default::default());
            $crate::runtime_tests::tensormap::test_tensormap_store::<TestRuntime, FloatType>(
                client,
            );
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_tensormap_metadata() {
            let client = TestRuntime::client(&Default::default());
            $crate::runtime_tests::tensormap::test_tensormap_metadata::<TestRuntime, FloatType>(
                client,
            );
        }
    };
}