1use crate::{element::JitElement, kernel, tensor::JitTensor, BoolElement, JitRuntime};
2use burn_tensor::{Shape, TensorData};
3use cubecl::tensor_vectorization_factor;
4
5pub(crate) fn from_data<R: JitRuntime, E: JitElement>(
6 data: TensorData,
7 device: &R::Device,
8) -> JitTensor<R> {
9 let shape: Shape = (&data.shape).into();
10 let client = R::client(device);
11 let buffer = client.create(data.convert::<E>().as_bytes());
12
13 JitTensor::new_contiguous(client, device.clone(), shape, buffer, E::dtype())
14}
15
16pub(crate) async fn into_data<R: JitRuntime, E: JitElement>(tensor: JitTensor<R>) -> TensorData {
17 let tensor = kernel::into_contiguous(tensor);
18
19 let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;
20 let actual_len = tensor.shape.num_elements() * size_of::<E>();
21 TensorData::new(E::from_bytes(&bytes[..actual_len]).to_vec(), tensor.shape)
22}
23
24#[allow(unused, reason = "useful for debugging kernels")]
26pub fn into_data_sync<R: JitRuntime, E: JitElement>(tensor: JitTensor<R>) -> TensorData {
27 let tensor = kernel::into_contiguous(tensor);
28
29 let bytes = tensor.client.read_one(tensor.handle.binding());
30 let actual_len = tensor.shape.num_elements() * size_of::<E>();
31 TensorData::new(E::from_bytes(&bytes[..actual_len]).to_vec(), tensor.shape)
32}
33
34pub(crate) async fn bool_into_data<R: JitRuntime, BT: BoolElement>(
35 tensor: JitTensor<R>,
36) -> TensorData {
37 let tensor = kernel::into_contiguous(tensor);
38 let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;
39 let actual_len = tensor.shape.num_elements() * size_of::<BT>();
40 TensorData::new(
41 BT::from_bytes(&bytes[..actual_len])
42 .iter()
43 .map(|i| *i != BT::false_val())
44 .collect(),
45 tensor.shape,
46 )
47}
48
49pub(crate) fn to_device<R: JitRuntime>(tensor: JitTensor<R>, device: &R::Device) -> JitTensor<R> {
50 if &tensor.device == device {
51 return tensor;
52 }
53
54 let client = R::client(device);
55 tensor.to_client(client, device.clone())
56}
57
58pub(crate) fn empty<R: JitRuntime, E: JitElement>(
59 shape: Shape,
60 device: &R::Device,
61) -> JitTensor<R> {
62 let client = R::client(device);
63 let buffer = client.empty(shape.num_elements() * core::mem::size_of::<E>());
64
65 JitTensor::new_contiguous(client, device.clone(), shape, buffer, E::dtype())
66}
67
68pub(crate) fn swap_dims<R: JitRuntime>(
69 mut tensor: JitTensor<R>,
70 dim1: usize,
71 dim2: usize,
72) -> JitTensor<R> {
73 tensor.strides.swap(dim1, dim2);
74 tensor.shape.dims.swap(dim1, dim2);
75
76 tensor
77}
78
79pub fn permute<R: JitRuntime>(mut tensor: JitTensor<R>, axes: &[usize]) -> JitTensor<R> {
80 tensor.strides = axes.iter().map(|i| tensor.strides[*i]).collect();
82
83 tensor.shape.dims = axes.iter().map(|i| tensor.shape.dims[*i]).collect();
85
86 tensor
87}
/// Broadcasts `tensor` to `target_shape` without copying data.
///
/// Works right-to-left over the target dimensions:
/// - a dimension whose extent matches the input keeps the input stride,
/// - a size-1 input dimension broadcast to a larger extent gets stride 0,
/// - leading dimensions added by the target shape get stride 0.
///
/// # Panics
/// Panics when an input dimension neither matches the corresponding target
/// dimension nor has size 1.
pub(crate) fn expand<R: JitRuntime>(tensor: JitTensor<R>, target_shape: Shape) -> JitTensor<R> {
    let ndims_in = tensor.shape.num_dims();
    let ndims_out = target_shape.num_dims();

    // All strides start at 0 (the broadcast stride); matching dims are overwritten below.
    let mut new_strides = vec![0usize; ndims_out];

    // Number of leading dimensions the target shape adds in front of the input.
    let dim_diff = ndims_out.saturating_sub(ndims_in);

    // Walk both shapes from the innermost dimension outwards, in lockstep.
    let mut tensor_dim_iter = tensor.shape.dims.iter().rev();
    for i in (0..ndims_out).rev() {
        if i >= dim_diff {
            if let Some(&tensor_dim) = tensor_dim_iter.next() {
                if tensor_dim == target_shape.dims[i] || tensor_dim == 1 {
                    new_strides[i] = if tensor_dim == target_shape.dims[i] {
                        // Same extent: reuse the input stride, offset by the added leading dims.
                        tensor.strides[i - dim_diff]
                    } else {
                        // Size-1 dim stretched to a larger extent: stride 0 repeats the value.
                        0
                    };
                } else {
                    panic!(
                        "Dimension mismatch: cannot broadcast dimension {} of tensor to target shape",
                        tensor_dim
                    );
                }
            } else {
                // Input ran out of dimensions; treat the missing dim as size 1 (stride 0).
                new_strides[i] = 0;
            }
        } else {
            // Leading dimension introduced by the target shape: broadcast with stride 0.
            new_strides[i] = 0;
        }
    }

    // Same client/handle/buffer — only the view metadata changes.
    JitTensor {
        client: tensor.client,
        device: tensor.device,
        shape: target_shape,
        strides: new_strides,
        handle: tensor.handle,
        dtype: tensor.dtype,
    }
}
137
138pub(crate) fn reshape<R: JitRuntime>(tensor: JitTensor<R>, shape: Shape) -> JitTensor<R> {
139 let tensor = kernel::into_contiguous(tensor);
141
142 JitTensor::new_contiguous(
143 tensor.client,
144 tensor.device,
145 shape,
146 tensor.handle,
147 tensor.dtype,
148 )
149}
150
151pub(crate) fn max_vectorization<R: JitRuntime>(tensor: &JitTensor<R>) -> u8 {
152 tensor_vectorization_factor(
153 R::supported_line_sizes(),
154 &tensor.shape.dims,
155 &tensor.strides,
156 tensor.shape.num_dims() - 1,
157 )
158}