use crate::{
CubeElement, CubeRuntime,
kernel::utils::{address_type, linear_view},
ops::{max_line_size, numeric::empty_device},
tensor::CubeTensor,
};
use burn_backend::TensorMetadata;
use cubecl::{
CubeDim, calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView,
};
/// Element-wise kernel that converts integer-encoded booleans into the numeric type `T`.
///
/// Each unit handles the line at `ABSOLUTE_POS`: the input line is masked with `1`
/// and the result is cast to `Line<T>` and stored in `output` at the same position.
///
/// * `input` - linear view over the source lines, with integer element type `B`.
/// * `output` - writable linear view that receives the converted lines.
/// * `_input_ty` - storage type used to define the generic `B`; the value itself is
///   not read in the body. NOTE(review): presumably resolved from the input tensor's
///   dtype at launch time — confirm against the `#[define]` semantics.
#[cube(launch_unchecked, address_type = "dynamic")]
fn bool_cast_kernel<B: Int, T: Numeric>(
input: &LinearView<Line<B>>,
output: &mut LinearView<Line<T>, ReadWrite>,
#[define(B)] _input_ty: StorageType,
) {
// Units whose position lies outside the output view have nothing to write.
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
// `& 1` keeps only the least-significant bit of every lane, so each lane is 0 or 1
// before it is cast to `T`.
output[ABSOLUTE_POS] = Line::cast_from(input[ABSOLUTE_POS] & Line::cast_from(1u32));
}
/// Converts `tensor` into a new tensor whose elements have type `EO`.
///
/// The result is allocated on the same client and device as `tensor`, with the same
/// shape, and is filled by launching `bool_cast_kernel` over the input.
///
/// # Panics
///
/// Panics with `"Kernel to never fail"` if the kernel launch reports an error.
pub fn bool_cast<R: CubeRuntime, EO: CubeElement>(tensor: CubeTensor<R>) -> CubeTensor<R> {
    let client = &tensor.client;
    let output = empty_device::<R, EO>(client.clone(), tensor.device.clone(), tensor.shape());

    // Launch geometry: one working unit per line of `vectorization` elements.
    let vectorization = max_line_size(&tensor);
    let unit_count = tensor.meta.num_elements() / vectorization as usize;
    let cube_dim = CubeDim::new(client, unit_count);
    let cube_count = calculate_cube_count_elemwise(client, unit_count, cube_dim);

    // SAFETY: the kernel checks `output.is_in_bounds(ABSOLUTE_POS)` before touching
    // memory. NOTE(review): this relies on the input view being at least as long as
    // the output view; both are built from the same shape and line size here.
    let launched = unsafe {
        bool_cast_kernel::launch_unchecked::<EO, R>(
            client,
            cube_count,
            cube_dim,
            address_type!(tensor, output),
            linear_view(&tensor, vectorization),
            linear_view(&output, vectorization),
            tensor.dtype.into(),
        )
    };
    launched.expect("Kernel to never fail");

    output
}