burn_jit/kernel/cast/base.rs

1use crate::{tensor::JitTensor, JitElement, JitRuntime};
2use cubecl::linalg::tensor::index_offset_with_layout;
3use cubecl::{calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor};
4use std::any::TypeId;
5
6#[cube(launch)]
7pub(crate) fn cast_element<I: CubePrimitive, O: CubePrimitive>(
8    input: &Tensor<Line<I>>,
9    output: &mut Tensor<Line<O>>,
10    #[comptime] rank: Option<u32>,
11) {
12    let offset_output = ABSOLUTE_POS;
13
14    if offset_output >= output.len() {
15        return;
16    }
17
18    let offset_input = index_offset_with_layout::<I, O>(
19        input,
20        output,
21        offset_output,
22        0,
23        rank.unwrap_or_else(|| output.rank()),
24        rank.is_some(),
25    );
26
27    output[offset_output] = Line::cast_from(input[offset_input]);
28}
29
30/// Cast a tensor to the given element type.
31///
32/// Note: When input element is semantically a boolean, prefer bool_cast function.
33pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement>(input: JitTensor<R>) -> JitTensor<R> {
34    if TypeId::of::<EI>() == TypeId::of::<EO>() {
35        return JitTensor::new_contiguous(
36            input.client,
37            input.device,
38            input.shape,
39            input.handle,
40            input.dtype,
41        );
42    }
43
44    // Vectorization is only enabled when the last dimension is contiguous.
45    let rank = input.shape.num_dims();
46    let vectorization_factor =
47        tensor_vectorization_factor(&[4, 2], &input.shape.dims, &input.strides, rank - 1);
48
49    let num_elems: usize = input.shape.num_elements();
50
51    let cube_dim = CubeDim::default();
52    let cube_count =
53        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
54    let client = input.client.clone();
55    let handle = client.empty(num_elems * core::mem::size_of::<EO>());
56    let output = JitTensor::new_contiguous(
57        client.clone(),
58        input.device.clone(),
59        input.shape.clone(),
60        handle,
61        EO::dtype(),
62    );
63
64    cast_element::launch::<EI, EO, R>(
65        &client,
66        cube_count,
67        cube_dim,
68        input.as_tensor_arg::<EI>(vectorization_factor),
69        output.as_tensor_arg::<EO>(vectorization_factor),
70        Some(rank as u32),
71    );
72
73    output
74}