// cubecl_std/tensor/identity.rs

1use cubecl::frontend::TensorHandleRef;
2use cubecl::prelude::*;
3use cubecl::tensor_line_size_parallel;
4use cubecl_core as cubecl;
5
6use super::TensorHandle;
7
8#[cube(launch_unchecked)]
9fn identity_kernel<C: Numeric>(
10    output: &mut Tensor<Line<C>>,
11    gap: usize,
12    #[define(C)] _elem: StorageType,
13) {
14    let pos_x = ABSOLUTE_POS_X as usize * output.line_size();
15    let pos_y = ABSOLUTE_POS_Y as usize;
16    if pos_y < output.shape(0) && pos_x < output.shape(1) {
17        let mut line = Line::empty(output.line_size()).fill(C::from_int(0));
18        let offs_y = pos_y * output.stride(0);
19
20        let start_pos = offs_y + pos_x;
21        let mut offset = 0;
22        while offset < output.line_size() {
23            let remainder = (start_pos + offset) % gap;
24            if remainder == 0 {
25                line[offset] = C::from_int(1);
26                offset += gap;
27            } else {
28                offset += gap - remainder;
29            }
30        }
31        output[start_pos / output.line_size()] = line;
32    }
33}
34
35/// Launch identity matrix kernel.
36/// Ensure output is a [`TensorHandle`] containing a square matrix.
37/// output will contain the identity matrix.
38pub fn launch<R: Runtime>(client: &ComputeClient<R>, output: &TensorHandle<R>) {
39    let dtype = output.dtype;
40    launch_ref(client, &output.as_ref(), dtype);
41}
42
43/// Launch identity matrix kernel by ref.
44/// Ensure output is a [`TensorHandleRef`] containing a square matrix.
45/// output will contain the identity matrix.
46pub fn launch_ref<R: Runtime>(
47    client: &ComputeClient<R>,
48    output: &TensorHandleRef<R>,
49    dtype: StorageType,
50) {
51    assert_eq!(2, output.shape.len(), "input should be a matrix");
52    assert_eq!(
53        output.shape[0], output.shape[1],
54        "input should be a square matrix"
55    );
56
57    let vectorization_factor = tensor_line_size_parallel(
58        R::supported_line_sizes().iter().cloned(),
59        output.shape,
60        output.strides,
61        1,
62    );
63
64    let cube_dim = CubeDim::new_2d(16, 16);
65    let lines_x = output.shape[1] as u32 / vectorization_factor as u32;
66    let cube_count_x = lines_x.div_ceil(cube_dim.x);
67    let cube_count_y = (output.shape[0] as u32).div_ceil(cube_dim.y);
68    let cube_count = CubeCount::new_2d(cube_count_x, cube_count_y);
69
70    unsafe {
71        identity_kernel::launch_unchecked(
72            client,
73            cube_count,
74            cube_dim,
75            TensorArg::from_raw_parts_and_size(
76                output.handle,
77                output.strides,
78                output.shape,
79                vectorization_factor,
80                dtype.size(),
81            ),
82            ScalarArg::new(output.strides[0] + 1),
83            dtype,
84        )
85        .expect("Should be able to launch the kernel all the time")
86    }
87}