burn_cubecl/kernel/cast/
bool_cast.rs1use crate::{
2 CubeRuntime,
3 kernel::utils::address_type,
4 ops::{max_vector_size, numeric::empty_device_dtype},
5 tensor::CubeTensor,
6};
7use burn_backend::TensorMetadata;
8use burn_std::DType;
9use cubecl::{
10 CubeDim, calculate_cube_count_elemwise, num_traits::One, prelude::*,
11 std::tensor::layout::linear::LinearView,
12};
13
14#[cube(launch_unchecked, address_type = "dynamic")]
15fn bool_cast_kernel<B: Int, T: Numeric, N: Size>(
16 input: &LinearView<Vector<B, N>>,
17 output: &mut LinearView<Vector<T, N>, ReadWrite>,
18 #[define(B, T)] _dtypes: [StorageType; 2],
19) {
20 if !output.is_in_bounds(ABSOLUTE_POS) {
21 terminate!();
22 }
23
24 output[ABSOLUTE_POS] = Vector::cast_from(input[ABSOLUTE_POS] & Vector::one());
25}
26
27pub fn bool_cast<R: CubeRuntime>(tensor: CubeTensor<R>, out_dtype: DType) -> CubeTensor<R> {
34 let output = empty_device_dtype(
35 tensor.client.clone(),
36 tensor.device.clone(),
37 tensor.shape(),
38 out_dtype,
39 );
40
41 let vector_size = max_vector_size(&tensor);
42 let num_elems = tensor.meta.num_elements();
43 let working_units = num_elems / vector_size as usize;
44 let cube_dim = CubeDim::new(&tensor.client, working_units);
45 let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
46
47 let dtype = tensor.dtype;
48
49 unsafe {
50 bool_cast_kernel::launch_unchecked(
51 &output.client,
52 cube_count,
53 cube_dim,
54 address_type!(tensor, output),
55 vector_size,
56 tensor.into_linear_view(),
57 output.clone().into_linear_view(),
58 [dtype.into(), out_dtype.into()],
59 )
60 };
61
62 output
63}