Skip to main content

burn_cubecl/kernel/cast/
base.rs

1use crate::{
2    CubeRuntime,
3    kernel::utils::address_type,
4    ops::{max_vector_size, numeric::empty_device_dtype},
5    tensor::CubeTensor,
6};
7use burn_backend::{DType, TensorMetadata};
8use cubecl::std::tensor::layout::linear::LinearView;
9use cubecl::{calculate_cube_count_elemwise, prelude::*};
10
11#[cube(launch, address_type = "dynamic")]
12pub(crate) fn cast_element<I: Numeric, O: Numeric, N: Size>(
13    input: &LinearView<Vector<I, N>>,
14    output: &mut LinearView<Vector<O, N>, ReadWrite>,
15    #[define(I, O)] _dtypes: [StorageType; 2],
16) {
17    if !output.is_in_bounds(ABSOLUTE_POS) {
18        terminate!();
19    }
20
21    output[ABSOLUTE_POS] = Vector::cast_from(input[ABSOLUTE_POS]);
22}
23
24/// Cast a tensor to the given element type.
25///
26/// Note: When input element is semantically a boolean, prefer bool_cast function.
27pub fn cast<R: CubeRuntime>(input: CubeTensor<R>, dtype: DType) -> CubeTensor<R> {
28    let dtype_output = match dtype {
29        DType::Flex32 => DType::F32,
30        _ => dtype,
31    };
32    let dtype_input = match input.dtype {
33        DType::Flex32 => DType::F32,
34        _ => input.dtype,
35    };
36
37    if dtype_input == dtype_output {
38        return input;
39    }
40
41    let client = input.client.clone();
42
43    let vector_size = max_vector_size(&input);
44
45    let num_elems: usize = input.meta.num_elements();
46
47    let working_units = num_elems / vector_size as usize;
48    let cube_dim = CubeDim::new(&client, working_units);
49    let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);
50
51    let output = empty_device_dtype(
52        client.clone(),
53        input.device.clone(),
54        input.shape(),
55        dtype, // We take the same dtype as passed as input (Flex32 not F32)
56    );
57
58    cast_element::launch(
59        &client,
60        cube_count,
61        cube_dim,
62        address_type!(input, output),
63        vector_size,
64        input.into_linear_view(),
65        output.clone().into_linear_view(),
66        [dtype_input.into(), dtype_output.into()],
67    );
68
69    output
70}