burn_jit/kernel/cast/
base.rs1use crate::{tensor::JitTensor, JitElement, JitRuntime};
2use cubecl::linalg::tensor::index_offset_with_layout;
3use cubecl::{calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor};
4use std::any::TypeId;
5
6#[cube(launch)]
7pub(crate) fn cast_element<I: CubePrimitive, O: CubePrimitive>(
8 input: &Tensor<Line<I>>,
9 output: &mut Tensor<Line<O>>,
10 #[comptime] rank: Option<u32>,
11) {
12 let offset_output = ABSOLUTE_POS;
13
14 if offset_output >= output.len() {
15 return;
16 }
17
18 let offset_input = index_offset_with_layout::<I, O>(
19 input,
20 output,
21 offset_output,
22 0,
23 rank.unwrap_or_else(|| output.rank()),
24 rank.is_some(),
25 );
26
27 output[offset_output] = Line::cast_from(input[offset_input]);
28}
29
30pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement>(input: JitTensor<R>) -> JitTensor<R> {
34 if TypeId::of::<EI>() == TypeId::of::<EO>() {
35 return JitTensor::new_contiguous(
36 input.client,
37 input.device,
38 input.shape,
39 input.handle,
40 input.dtype,
41 );
42 }
43
44 let rank = input.shape.num_dims();
46 let vectorization_factor =
47 tensor_vectorization_factor(&[4, 2], &input.shape.dims, &input.strides, rank - 1);
48
49 let num_elems: usize = input.shape.num_elements();
50
51 let cube_dim = CubeDim::default();
52 let cube_count =
53 calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
54 let client = input.client.clone();
55 let handle = client.empty(num_elems * core::mem::size_of::<EO>());
56 let output = JitTensor::new_contiguous(
57 client.clone(),
58 input.device.clone(),
59 input.shape.clone(),
60 handle,
61 EO::dtype(),
62 );
63
64 cast_element::launch::<EI, EO, R>(
65 &client,
66 cube_count,
67 cube_dim,
68 input.as_tensor_arg::<EI>(vectorization_factor),
69 output.as_tensor_arg::<EO>(vectorization_factor),
70 Some(rank as u32),
71 );
72
73 output
74}