cubecl_std/tensor/
identity.rs1use cubecl::frontend::TensorHandleRef;
2use cubecl::prelude::*;
3use cubecl::tensor_line_size_parallel;
4use cubecl_core as cubecl;
5
6use super::TensorHandle;
7
8#[cube(launch_unchecked)]
9fn identity_kernel<C: Numeric>(output: &mut Tensor<Line<C>>, gap: u32) {
10 let pos_x = ABSOLUTE_POS_X * output.line_size();
11 if ABSOLUTE_POS_Y < output.shape(0) && pos_x < output.shape(1) {
12 let mut line = Line::empty(output.line_size()).fill(C::from_int(0));
13 let offs_y = ABSOLUTE_POS_Y * output.stride(0);
14
15 let start_pos = offs_y + pos_x;
16 let mut offset = 0;
17 while offset < output.line_size() {
18 let remainder = (start_pos + offset) % gap;
19 if remainder % gap == 0 {
20 line[offset] = C::from_int(1);
21 offset += gap;
22 } else {
23 offset += gap - remainder;
24 }
25 }
26 output[start_pos / output.line_size()] = line;
27 }
28}
29
30pub fn launch<R: Runtime, C: Numeric>(
34 client: &ComputeClient<R::Server, R::Channel>,
35 output: &TensorHandle<R, C>,
36) {
37 launch_ref::<R, C>(client, &output.as_ref());
38}
39
40pub fn launch_ref<R: Runtime, C: Numeric>(
44 client: &ComputeClient<R::Server, R::Channel>,
45 output: &TensorHandleRef<R>,
46) {
47 assert_eq!(2, output.shape.len(), "input should be a matrix");
48 assert_eq!(
49 output.shape[0], output.shape[1],
50 "input should be a square matrix"
51 );
52
53 let vectorization_factor = tensor_line_size_parallel(
54 R::supported_line_sizes().iter().cloned(),
55 output.shape,
56 output.strides,
57 1,
58 );
59
60 let cube_dim = CubeDim::default();
61 let lines_x = output.shape[1] as u32 / vectorization_factor as u32;
62 let cube_count_x = lines_x.div_ceil(cube_dim.x);
63 let cube_count_y = (output.shape[0] as u32).div_ceil(cube_dim.y);
64 let cube_count = CubeCount::new_2d(cube_count_x, cube_count_y);
65
66 unsafe {
67 identity_kernel::launch_unchecked::<C, R>(
68 client,
69 cube_count,
70 cube_dim,
71 TensorArg::from_raw_parts::<C>(
72 output.handle,
73 output.strides,
74 output.shape,
75 vectorization_factor,
76 ),
77 ScalarArg::new(output.strides[0] as u32 + 1),
78 );
79 }
80}