cubecl_std/tensor/
identity.rs

1use cubecl::frontend::TensorHandleRef;
2use cubecl::prelude::*;
3use cubecl::tensor_line_size_parallel;
4use cubecl_core as cubecl;
5
6use super::TensorHandle;
7
8#[cube(launch_unchecked)]
9fn identity_kernel<C: Numeric>(
10    output: &mut Tensor<Line<C>>,
11    gap: u32,
12    #[define(C)] _elem: StorageType,
13) {
14    let pos_x = ABSOLUTE_POS_X * output.line_size();
15    if ABSOLUTE_POS_Y < output.shape(0) && pos_x < output.shape(1) {
16        let mut line = Line::empty(output.line_size()).fill(C::from_int(0));
17        let offs_y = ABSOLUTE_POS_Y * output.stride(0);
18
19        let start_pos = offs_y + pos_x;
20        let mut offset = 0;
21        while offset < output.line_size() {
22            let remainder = (start_pos + offset) % gap;
23            if remainder.is_multiple_of(gap) {
24                line[offset] = C::from_int(1);
25                offset += gap;
26            } else {
27                offset += gap - remainder;
28            }
29        }
30        output[start_pos / output.line_size()] = line;
31    }
32}
33
34/// Launch identity matrix kernel.
35/// Ensure output is a [`TensorHandle`] containing a square matrix.
36/// output will contain the identity matrix.
37pub fn launch<R: Runtime>(client: &ComputeClient<R::Server>, output: &TensorHandle<R>) {
38    let dtype = output.dtype;
39    launch_ref::<R>(client, &output.as_ref(), dtype);
40}
41
42/// Launch identity matrix kernel by ref.
43/// Ensure output is a [`TensorHandleRef`] containing a square matrix.
44/// output will contain the identity matrix.
45pub fn launch_ref<R: Runtime>(
46    client: &ComputeClient<R::Server>,
47    output: &TensorHandleRef<R>,
48    dtype: StorageType,
49) {
50    assert_eq!(2, output.shape.len(), "input should be a matrix");
51    assert_eq!(
52        output.shape[0], output.shape[1],
53        "input should be a square matrix"
54    );
55
56    let vectorization_factor = tensor_line_size_parallel(
57        R::supported_line_sizes().iter().cloned(),
58        output.shape,
59        output.strides,
60        1,
61    );
62
63    let cube_dim = CubeDim::default();
64    let lines_x = output.shape[1] as u32 / vectorization_factor as u32;
65    let cube_count_x = lines_x.div_ceil(cube_dim.x);
66    let cube_count_y = (output.shape[0] as u32).div_ceil(cube_dim.y);
67    let cube_count = CubeCount::new_2d(cube_count_x, cube_count_y);
68
69    unsafe {
70        identity_kernel::launch_unchecked::<R>(
71            client,
72            cube_count,
73            cube_dim,
74            TensorArg::from_raw_parts_and_size(
75                output.handle,
76                output.strides,
77                output.shape,
78                vectorization_factor,
79                dtype.size(),
80            ),
81            ScalarArg::new(output.strides[0] as u32 + 1),
82            dtype,
83        );
84    }
85}