cubecl_std/tensor/
identity.rs1use cubecl::frontend::TensorBinding;
2use cubecl::prelude::*;
3use cubecl::tensor_vector_size_parallel;
4use cubecl_core as cubecl;
5
6use super::TensorHandle;
7
8#[cube(launch_unchecked, address_type = "dynamic")]
9fn identity_kernel<C: Numeric, N: Size>(
10 output: &mut Tensor<Vector<C, N>>,
11 gap: usize,
12 #[define(C)] _elem: StorageType,
13) {
14 let pos_x = ABSOLUTE_POS_X as usize * output.vector_size();
15 let pos_y = ABSOLUTE_POS_Y as usize;
16 if pos_y < output.shape(0) && pos_x < output.shape(1) {
17 let mut vector = Vector::new(C::from_int(0));
18 let offs_y = pos_y * output.stride(0);
19
20 let start_pos = offs_y + pos_x;
21 let mut offset = 0;
22 while offset < output.vector_size() {
23 let remainder = (start_pos + offset) % gap;
24 if remainder == 0 {
25 vector[offset] = C::from_int(1);
26 offset += gap;
27 } else {
28 offset += gap - remainder;
29 }
30 }
31 output[start_pos / output.vector_size()] = vector;
32 }
33}
34
35pub fn launch<R: Runtime>(client: &ComputeClient<R>, output: &TensorHandle<R>) {
39 let dtype = output.dtype;
40 launch_ref(client, output.clone().binding(), dtype);
41}
42
43pub fn launch_ref<R: Runtime>(
47 client: &ComputeClient<R>,
48 output: TensorBinding<R>,
49 dtype: StorageType,
50) {
51 assert_eq!(2, output.shape.len(), "input should be a matrix");
52 assert_eq!(
53 output.shape[0], output.shape[1],
54 "input should be a square matrix"
55 );
56
57 let vectorization_factor = tensor_vector_size_parallel(
58 client.io_optimized_vector_sizes(dtype.size()),
59 &output.shape,
60 &output.strides,
61 1,
62 );
63
64 let cube_dim = CubeDim::new_2d(16, 16);
65 let vectors_x = output.shape[1] as u32 / vectorization_factor as u32;
66 let cube_count_x = vectors_x.div_ceil(cube_dim.x);
67 let cube_count_y = (output.shape[0] as u32).div_ceil(cube_dim.y);
68 let cube_count = CubeCount::new_2d(cube_count_x, cube_count_y);
69
70 let scalar = output.strides[0] + 1;
71 unsafe {
72 identity_kernel::launch_unchecked(
73 client,
74 cube_count,
75 cube_dim,
76 output.required_address_type(dtype.size()),
77 vectorization_factor,
78 output.into_tensor_arg(),
79 scalar,
80 dtype,
81 )
82 }
83}