// burn_cubecl/kernel/cast/base.rs
use crate::{
    CubeRuntime,
    kernel::utils::address_type,
    ops::{max_vector_size, numeric::empty_device_dtype},
    tensor::CubeTensor,
};
use burn_backend::{DType, TensorMetadata};
use cubecl::std::tensor::layout::linear::LinearView;
use cubecl::{calculate_cube_count_elemwise, prelude::*};

/// Elementwise cast kernel: converts each `N`-wide vector of `input` from
/// numeric type `I` to numeric type `O` and writes it to `output`.
///
/// One working unit (identified by `ABSOLUTE_POS`) handles one vector.
#[cube(launch, address_type = "dynamic")]
pub(crate) fn cast_element<I: Numeric, O: Numeric, N: Size>(
    input: &LinearView<Vector<I, N>>,
    output: &mut LinearView<Vector<O, N>, ReadWrite>,
    // NOTE(review): presumably binds the runtime storage types to the
    // comptime generics `I` and `O` — confirm against cubecl's `#[define]`
    // semantics.
    #[define(I, O)] _dtypes: [StorageType; 2],
) {
    // Guard: units whose position falls past the end of `output` do no work,
    // so a partially-filled last cube is safe.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    // Vectorized cast of one element: `Vector::cast_from` converts all `N`
    // lanes from `I` to `O` at once.
    output[ABSOLUTE_POS] = Vector::cast_from(input[ABSOLUTE_POS]);
}
23
24pub fn cast<R: CubeRuntime>(input: CubeTensor<R>, dtype: DType) -> CubeTensor<R> {
28 let dtype_output = match dtype {
29 DType::Flex32 => DType::F32,
30 _ => dtype,
31 };
32 let dtype_input = match input.dtype {
33 DType::Flex32 => DType::F32,
34 _ => input.dtype,
35 };
36
37 if dtype_input == dtype_output {
38 return input;
39 }
40
41 let client = input.client.clone();
42
43 let vector_size = max_vector_size(&input);
44
45 let num_elems: usize = input.meta.num_elements();
46
47 let working_units = num_elems / vector_size as usize;
48 let cube_dim = CubeDim::new(&client, working_units);
49 let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);
50
51 let output = empty_device_dtype(
52 client.clone(),
53 input.device.clone(),
54 input.shape(),
55 dtype, );
57
58 cast_element::launch(
59 &client,
60 cube_count,
61 cube_dim,
62 address_type!(input, output),
63 vector_size,
64 input.into_linear_view(),
65 output.clone().into_linear_view(),
66 [dtype_input.into(), dtype_output.into()],
67 );
68
69 output
70}