use crate::{
CubeRuntime,
kernel::utils::address_type,
ops::{max_vector_size, numeric::empty_device_dtype},
tensor::CubeTensor,
};
use burn_backend::{DType, TensorMetadata};
use cubecl::std::tensor::layout::linear::LinearView;
use cubecl::{calculate_cube_count_elemwise, prelude::*};
/// Element-wise cast kernel: each unit converts one `N`-wide vector from `I` to `O`.
///
/// Both tensors are accessed through `LinearView`s indexed by `ABSOLUTE_POS`. Units whose
/// position is outside `output` exit without writing anything.
///
/// The concrete `I` / `O` element types come from `_dtypes`, which `cast` passes as
/// `[input_dtype, output_dtype]`. NOTE(review): assumes `#[define(I, O)]` binds the generics
/// from these runtime storage types at launch — confirm against the cubecl docs.
#[cube(launch, address_type = "dynamic")]
pub(crate) fn cast_element<I: Numeric, O: Numeric, N: Size>(
input: &LinearView<Vector<I, N>>,
output: &mut LinearView<Vector<O, N>, ReadWrite>,
// Runtime element types for the `I` and `O` generics, in that order.
#[define(I, O)] _dtypes: [StorageType; 2],
) {
// Presumably the launch is rounded up to whole cubes, so trailing units may have no
// output vector to write; they stop here.
if !output.is_in_bounds(ABSOLUTE_POS) {
terminate!();
}
// Only `output` is bounds-checked; this assumes `input` has at least as many vectors
// (same shape in `cast`) — TODO confirm for any other caller.
output[ABSOLUTE_POS] = Vector::cast_from(input[ABSOLUTE_POS]);
}
/// Casts `input` to `dtype` by launching the `cast_element` kernel on the input's device.
///
/// `DType::Flex32` is treated as `DType::F32` on both sides. That normalized type is used both
/// to decide whether any work is needed and as the kernel's element types. As a consequence:
///
/// * If the normalized input and output dtypes match (including Flex32 <-> F32), `input` is
///   returned unchanged. It keeps its original `dtype` label and no kernel is launched.
/// * Otherwise a new tensor is allocated with the same shape, on the same device, with the
///   caller-requested (un-normalized) `dtype`. It is filled by the kernel and returned.
pub fn cast<R: CubeRuntime>(input: CubeTensor<R>, dtype: DType) -> CubeTensor<R> {
    // Single normalization rule shared by the no-op check and the kernel dtypes, so the two
    // cannot drift apart.
    let kernel_dtype = |dt: DType| -> DType {
        match dt {
            DType::Flex32 => DType::F32,
            other => other,
        }
    };
    let dtype_input = kernel_dtype(input.dtype);
    let dtype_output = kernel_dtype(dtype);

    if dtype_input == dtype_output {
        return input;
    }

    let client = input.client.clone();
    // The vectorization width is chosen from the input tensor.
    let vector_size = max_vector_size(&input);
    let num_elems: usize = input.meta.num_elements();
    // One working unit per vector of `vector_size` elements.
    let working_units = num_elems / vector_size as usize;
    let cube_dim = CubeDim::new(&client, working_units);
    let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);

    // Allocate with the requested dtype (e.g. Flex32 stays Flex32), not the normalized one.
    let output = empty_device_dtype(client.clone(), input.device.clone(), input.shape(), dtype);

    cast_element::launch(
        &client,
        cube_count,
        cube_dim,
        address_type!(input, output),
        vector_size,
        input.into_linear_view(),
        output.clone().into_linear_view(),
        // Kernel element types: `[input, output]`, matching `#[define(I, O)]`.
        [dtype_input.into(), dtype_output.into()],
    );

    output
}