use crate::{CubeRuntime, kernel, tensor::CubeTensor};
use burn_backend::{
    DType, ExecutionError, QTensorPrimitive, Shape, TensorData,
    quantization::{QuantLevel, QuantStore, params_shape},
};
use burn_backend::{TensorMetadata, ops::unfold::calculate_unfold_shape};
use burn_std::tensor::{ReshapeAction, contiguous_strides, reshape_action};
use cubecl::{ir::LineSize, server::CopyDescriptor};
use cubecl::{quant::scheme::BlockSize, tensor_line_size_parallel};

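/// Creates a contiguous tensor on the given device from [`TensorData`].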
pub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {
    let shape: Shape = (&data.shape).into();
    let client = R::client(device);
    let buffer = client.create(data.bytes);

    CubeTensor::new_contiguous(client, device.clone(), shape, buffer, data.dtype)
}

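/// Reads the tensor's contents back from the device into [`TensorData`], asynchronously.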
pub(crate) async fn into_data<R: CubeRuntime>(
    tensor: CubeTensor<R>,
) -> Result<TensorData, ExecutionError> {
    let tensor = kernel::into_contiguous_aligned(tensor);

    let elem_size = tensor.elem_size();
    let shape = &tensor.shape.dims;
    let binding = CopyDescriptor::new(tensor.handle.binding(), shape, &tensor.strides, elem_size);
    let bytes = tensor
        .client
        .read_one_tensor_async(binding)
        .await
        .map_err(|err| ExecutionError::WithContext {
            reason: format!("{err}"),
        })?;

    Ok(TensorData::from_bytes(bytes, tensor.shape, tensor.dtype))
}

#[allow(unused, reason = "useful for debugging kernels")]
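/// Blocking version of [`into_data`]; panics if the read fails.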
pub fn into_data_sync<R: CubeRuntime>(tensor: CubeTensor<R>) -> TensorData {
    burn_std::future::block_on(into_data(tensor)).unwrap()
}

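/// Moves the tensor to the given device, returning it unchanged if it already lives there.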
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(level = "trace", skip(tensor, device))
)]
pub(crate) fn to_device<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    device: &R::Device,
) -> CubeTensor<R> {
    if &tensor.device == device {
        return tensor;
    }

    let tensor = kernel::into_contiguous_aligned(tensor);
    let client = R::client(device);
    tensor.to_client(client, device.clone())
}

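/// Allocates an uninitialized contiguous tensor with the given shape and dtype.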
pub(crate) fn empty<R: CubeRuntime>(
    shape: Shape,
    device: &R::Device,
    dtype: DType,
) -> CubeTensor<R> {
    let client = R::client(device);
    let buffer = client.empty(shape.num_elements() * dtype.size());

    CubeTensor::new_contiguous(client, device.clone(), shape, buffer, dtype)
}

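/// Swaps two dimensions of a tensor by swapping their strides, updating quantization
/// metadata (block sizes and packed dimension) when the tensor is quantized.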
pub(crate) fn swap_dims<R: CubeRuntime>(
    mut tensor: CubeTensor<R>,
    dim1: usize,
    dim2: usize,
) -> CubeTensor<R> {
    tensor.strides.swap(dim1, dim2);
    tensor.shape = tensor.shape.swap(dim1, dim2).unwrap();

    if let DType::QFloat(scheme) = tensor.dtype
        && let QuantLevel::Block(block_size) = scheme.level
    {
        let rank = tensor.rank();
        let qparams = tensor.qparams.as_mut().unwrap();
        let mut block_size = block_size.to_dim_vec(rank);
        block_size.swap(dim1, dim2);

        let block_size = BlockSize::new_trim(block_size);
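        // Guard against the swapped block exceeding the supported number of dims.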
        if block_size.len() > BlockSize::MAX_DIMS {
            panic!("Swapped block size would exceed max dims");
        }

        qparams.scales.shape.dims.swap(dim1, dim2);
        qparams.scales.strides.swap(dim1, dim2);

        tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::Block(block_size)))
    }

    if let DType::QFloat(scheme) = &mut tensor.dtype
        && let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) =
            &mut scheme.store
    {
        let rank = tensor.shape.len();

        if *packed_dim == rank - dim1 - 1 {
            *packed_dim = rank - dim2 - 1;
        } else if *packed_dim == rank - dim2 - 1 {
            *packed_dim = rank - dim1 - 1;
        }
    }

    tensor
}

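/// Permutes the dimensions of a tensor according to `axes`, updating quantization
/// metadata (block sizes and packed dimension) when the tensor is quantized.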
pub fn permute<R: CubeRuntime>(mut tensor: CubeTensor<R>, axes: &[usize]) -> CubeTensor<R> {
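    // Reorder the strides and the shape to match the new axis order.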
    tensor.strides = axes.iter().map(|i| tensor.strides[*i]).collect();

    tensor.shape = tensor.shape.permute(axes).unwrap();

    if let DType::QFloat(scheme) = tensor.dtype
        && let QuantLevel::Block(block_size) = scheme.level
    {
        let rank = tensor.rank();
        let qparams = tensor.qparams.as_mut().unwrap();

        let mut block_size = block_size.to_dim_vec(rank);
        block_size = axes.iter().map(|i| block_size[*i]).collect();

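        // Drop leading unit dims so only the effective block dims remain.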
        let block_size = block_size
            .into_iter()
            .skip_while(|it| *it == 1)
            .collect::<Vec<_>>();
        if block_size.len() > BlockSize::MAX_DIMS {
            panic!("Permuted block size would exceed max dims");
        }

        qparams.scales.strides = axes.iter().map(|i| qparams.scales.strides[*i]).collect();
        qparams.scales.shape = qparams.scales.shape.clone().permute(axes).unwrap();

        tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::block(&block_size)))
    }

    if let DType::QFloat(scheme) = &mut tensor.dtype
        && let QuantStore::PackedU32(packed_dim) = &mut scheme.store
    {
        let rank = tensor.shape.len();
        let new_pos = axes
            .iter()
            .position(|axis| *axis == rank - *packed_dim - 1)
            .unwrap_or(0);
        *packed_dim = rank - new_pos - 1;
    }

    tensor
}

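/// Permutes a tensor from channels-second (NCHW-like) to channels-last (NHWC-like)
/// layout, for any rank: `[N, C, ...]` becomes `[N, ..., C]`.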
pub fn permute_nchw_to_nhwc<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
    let rank = tensor.shape.num_dims();
    let c_dim = 1;

    let mut dims = vec![0];
    dims.extend(2..rank);
    dims.push(c_dim);

    permute(tensor, &dims)
}

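/// Computes the shape resulting from a channels-second to channels-last permutation.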
pub fn permute_nchw_to_nhwc_shape(shape: Shape) -> Shape {
    let rank = shape.num_dims();
    let c_dim = 1;

    let mut dims = vec![0];
    dims.extend(2..rank);
    dims.push(c_dim);

    shape.permute(&dims).expect("Shape permute should succeed")
}

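/// Permutes a tensor from channels-last (NHWC-like) back to channels-second (NCHW-like)
/// layout, for any rank: `[N, ..., C]` becomes `[N, C, ...]`.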
pub fn permute_nhwc_to_nchw<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
    let rank = tensor.shape.num_dims();
    let c_dim = rank - 1;

    let mut dims = vec![0];
    dims.push(c_dim);
    dims.extend(1..c_dim);

    permute(tensor, &dims)
}

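/// Computes the shape resulting from a channels-last to channels-second permutation.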
pub fn permute_nhwc_to_nchw_shape(shape: Shape) -> Shape {
    let rank = shape.num_dims();
    let c_dim = rank - 1;

    let mut dims = vec![0];
    dims.push(c_dim);
    dims.extend(1..c_dim);

    shape.permute(&dims).expect("Shape permute should succeed")
}

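/// Broadcasts a tensor to `target_shape` without copying data, by giving broadcast
/// dimensions a stride of zero. Panics if a dimension cannot be broadcast.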
pub(crate) fn expand<R: CubeRuntime>(tensor: CubeTensor<R>, target_shape: Shape) -> CubeTensor<R> {
    let ndims_in = tensor.shape.num_dims();
    let ndims_out = target_shape.num_dims();

    let mut new_strides = vec![0usize; ndims_out];

    let dim_diff = ndims_out.saturating_sub(ndims_in);

    let mut tensor_dim_iter = tensor.shape.iter().rev();
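    // Walk the target dims from innermost to outermost, aligning them with the input dims.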
    for i in (0..ndims_out).rev() {
        if i >= dim_diff {
            if let Some(&tensor_dim) = tensor_dim_iter.next() {
                if tensor_dim == target_shape[i] || tensor_dim == 1 {
                    new_strides[i] = if tensor_dim == target_shape[i] {
                        tensor.strides[i - dim_diff]
                    } else {
                        0
                    };
                } else {
                    panic!(
                        "Dimension mismatch: cannot broadcast dimension {tensor_dim} of tensor to target shape"
                    );
                }
            } else {
                new_strides[i] = 0;
            }
        } else {
            new_strides[i] = 0;
        }
    }

    if tensor.qparams.is_some() {
        match tensor.scheme().level {
            QuantLevel::Tensor => {}
            QuantLevel::Block(_) => todo!(),
        }
    }

    CubeTensor {
        client: tensor.client,
        device: tensor.device,
        shape: target_shape,
        strides: new_strides,
        handle: tensor.handle,
        dtype: tensor.dtype,
        qparams: tensor.qparams,
    }
}

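/// Reshapes a tensor, adjusting the strides in place when possible and falling back
/// to a contiguous copy otherwise.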
pub fn reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
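    // Determine whether the reshape can be expressed as a stride update alone.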
    let analysis = reshape_action(&tensor.shape.dims, &tensor.strides, &shape.dims);

    match analysis {
        ReshapeAction::UpdateStrides { strides } => {
            tensor.shape = shape;
            tensor.strides = strides;
            return tensor;
        }
        ReshapeAction::NoChange => return tensor,
        ReshapeAction::Recompute => (),
    }

    let tensor = kernel::into_contiguous(tensor);

    let mut out = CubeTensor::new_contiguous(
        tensor.client,
        tensor.device,
        shape,
        tensor.handle,
        tensor.dtype,
    );
    out.qparams = tensor.qparams;
    out
}

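/// Reshapes a quantized tensor, keeping the packed values and the quantization scales
/// consistent with the new shape.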
pub fn q_reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
    let scheme = *tensor.scheme();

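    // Values are packed along the last dim, so the stored shape divides it by the
    // number of quants per storage element.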
    let shape_values = {
        let rank = shape.num_dims();
        let mut shape = shape.clone();
        shape[rank - 1] = shape[rank - 1].div_ceil(scheme.num_quants());
        shape
    };
    let shape_scales = params_shape(&shape, scheme.level);
    let (values, scales) = tensor.quantized_handles().unwrap();

    let analysis_values = reshape_action(&values.shape.dims, &values.strides, &shape_values.dims);
    let analysis_scales = reshape_action(&scales.shape.dims, &scales.strides, &shape_scales.dims);

    match (analysis_values, analysis_scales) {
        (
            ReshapeAction::UpdateStrides { strides },
            ReshapeAction::UpdateStrides {
                strides: scales_strides,
            },
        ) => {
            let qparams = tensor.qparams.as_mut().unwrap();

            tensor.shape = shape;
            tensor.strides = strides;

            qparams.scales.shape = shape_scales;
            qparams.scales.strides = scales_strides;
        }
        (ReshapeAction::UpdateStrides { strides }, ReshapeAction::NoChange) => {
            tensor.shape = shape;
            tensor.strides = strides;
        }
        (
            ReshapeAction::NoChange,
            ReshapeAction::UpdateStrides {
                strides: scales_strides,
            },
        ) => {
            let qparams = tensor.qparams.as_mut().unwrap();

            qparams.scales.shape = shape_scales;
            qparams.scales.strides = scales_strides;
        }
        (ReshapeAction::NoChange, ReshapeAction::NoChange) => {}
        _ => {
            tensor = kernel::into_contiguous(tensor);
            tensor.shape = shape;
            tensor.strides = contiguous_strides(&shape_values.dims);

            let qparams = tensor.qparams.as_mut().unwrap();

            qparams.scales.strides = contiguous_strides(&shape_scales.dims);
            qparams.scales.shape = shape_scales;
        }
    }

    tensor
}

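/// Returns the widest line (vectorization) size usable along the last dimension of the tensor.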
pub(crate) fn max_line_size<R: CubeRuntime>(tensor: &CubeTensor<R>) -> LineSize {
    tensor_line_size_parallel(
        tensor
            .client
            .io_optimized_line_sizes_unchecked(tensor.dtype.size()),
        &tensor.shape,
        &tensor.strides,
        tensor.shape.len() - 1,
    )
}

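/// Returns the widest line size supported by every tensor along the given axis.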
pub(crate) fn max_line_size_many<R: CubeRuntime>(
    tensors: &[&CubeTensor<R>],
    axis: usize,
) -> LineSize {
    let min_line_size = tensors
        .iter()
        .map(|tensor| {
            tensor_line_size_parallel(
                tensor
                    .client
                    .io_optimized_line_sizes_unchecked(tensor.dtype.size()),
                &tensor.shape,
                &tensor.strides,
                axis,
            )
        })
        .min();

    min_line_size.unwrap_or(0)
}

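/// Returns a view of the tensor with an extra trailing dimension of size `size`,
/// containing the windows of length `size` taken every `step` elements along `dim`.
///
/// No data is copied: the view is built purely from new shape and strides, so
/// windows overlap whenever `step < size`.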
pub fn unfold<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    dim: usize,
    size: usize,
    step: usize,
) -> CubeTensor<R> {
    let shape = calculate_unfold_shape(tensor.shape, dim, size, step);

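    // Successive windows are `step` elements apart along `dim`, while the new trailing
    // dimension walks the elements within a window using the original stride.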
    let d_stride = tensor.strides[dim];
    let mut strides = tensor.strides.clone();
    strides[dim] = step * d_stride;
    strides.push(d_stride);

    CubeTensor {
        shape,
        strides,
        ..tensor
    }
}