Skip to main content

burn_cubecl/ops/
base.rs

1use crate::{CubeRuntime, kernel, tensor::CubeTensor};
2use burn_backend::{
3    DType, ExecutionError, QTensorPrimitive, Shape, TensorData,
4    quantization::{QuantLevel, QuantStore, params_shape},
5};
6use burn_backend::{TensorMetadata, ops::unfold::calculate_unfold_shape};
7use burn_std::tensor::{ReshapeAction, contiguous_strides, reshape_action};
8use cubecl::{ir::LineSize, server::CopyDescriptor};
9use cubecl::{quant::scheme::BlockSize, tensor_line_size_parallel};
10
11pub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {
12    let shape: Shape = (&data.shape).into();
13    let client = R::client(device);
14    let buffer = client.create(data.bytes);
15
16    CubeTensor::new_contiguous(client, device.clone(), shape, buffer, data.dtype)
17}
18
19pub(crate) async fn into_data<R: CubeRuntime>(
20    tensor: CubeTensor<R>,
21) -> Result<TensorData, ExecutionError> {
22    let tensor = kernel::into_contiguous_aligned(tensor);
23
24    let elem_size = tensor.elem_size();
25    let shape = &tensor.shape.dims;
26    let binding = CopyDescriptor::new(tensor.handle.binding(), shape, &tensor.strides, elem_size);
27    let bytes = tensor
28        .client
29        .read_one_tensor_async(binding)
30        .await
31        .map_err(|err| ExecutionError::WithContext {
32            reason: format!("{err}"),
33        })?;
34
35    Ok(TensorData::from_bytes(bytes, tensor.shape, tensor.dtype))
36}
37
38/// Read data from a `CubeTensor` synchronously
39#[allow(unused, reason = "useful for debugging kernels")]
40pub fn into_data_sync<R: CubeRuntime>(tensor: CubeTensor<R>) -> TensorData {
41    burn_std::future::block_on(into_data(tensor)).unwrap()
42}
43
44#[cfg_attr(
45    feature = "tracing",
46    tracing::instrument(level = "trace", skip(tensor, device))
47)]
48pub(crate) fn to_device<R: CubeRuntime>(
49    tensor: CubeTensor<R>,
50    device: &R::Device,
51) -> CubeTensor<R> {
52    if &tensor.device == device {
53        return tensor;
54    }
55
56    let tensor = kernel::into_contiguous_aligned(tensor);
57    let client = R::client(device);
58    tensor.to_client(client, device.clone())
59}
60
61pub(crate) fn empty<R: CubeRuntime>(
62    shape: Shape,
63    device: &R::Device,
64    dtype: DType,
65) -> CubeTensor<R> {
66    let client = R::client(device);
67    let buffer = client.empty(shape.num_elements() * dtype.size());
68
69    CubeTensor::new_contiguous(client, device.clone(), shape, buffer, dtype)
70}
71
72pub(crate) fn swap_dims<R: CubeRuntime>(
73    mut tensor: CubeTensor<R>,
74    dim1: usize,
75    dim2: usize,
76) -> CubeTensor<R> {
77    tensor.strides.swap(dim1, dim2);
78    tensor.shape = tensor.shape.swap(dim1, dim2).unwrap();
79
80    if let DType::QFloat(scheme) = tensor.dtype
81        && let QuantLevel::Block(block_size) = scheme.level
82    {
83        let rank = tensor.rank();
84        let qparams = tensor.qparams.as_mut().unwrap();
85        let mut block_size = block_size.to_dim_vec(rank);
86        block_size.swap(dim1, dim2);
87
88        // Truncate unit dims from the start
89        let block_size = BlockSize::new_trim(block_size);
90        if block_size.len() > BlockSize::MAX_DIMS {
91            panic!("Swapped block size would exceed max dims");
92        }
93
94        qparams.scales.shape.dims.swap(dim1, dim2);
95        qparams.scales.strides.swap(dim1, dim2);
96
97        tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::Block(block_size)))
98    }
99
100    if let DType::QFloat(scheme) = &mut tensor.dtype
101        && let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) =
102            &mut scheme.store
103    {
104        let rank = tensor.shape.len();
105
106        if *packed_dim == rank - dim1 - 1 {
107            *packed_dim = rank - dim2 - 1;
108        } else if *packed_dim == rank - dim2 - 1 {
109            *packed_dim = rank - dim1 - 1;
110        }
111    }
112
113    tensor
114}
115
116/// Permute a tensor's dimensions
117pub fn permute<R: CubeRuntime>(mut tensor: CubeTensor<R>, axes: &[usize]) -> CubeTensor<R> {
118    // remap strides
119    tensor.strides = axes.iter().map(|i| tensor.strides[*i]).collect();
120
121    // remap shape
122    tensor.shape = tensor.shape.permute(axes).unwrap();
123
124    if let DType::QFloat(scheme) = tensor.dtype
125        && let QuantLevel::Block(block_size) = scheme.level
126    {
127        let rank = tensor.rank();
128        let qparams = tensor.qparams.as_mut().unwrap();
129
130        let mut block_size = block_size.to_dim_vec(rank);
131        block_size = axes.iter().map(|i| block_size[*i]).collect();
132
133        // Truncate unit dims from the start
134        let block_size = block_size
135            .into_iter()
136            .skip_while(|it| *it == 1)
137            .collect::<Vec<_>>();
138        if block_size.len() > BlockSize::MAX_DIMS {
139            panic!("Swapped block size would exceed max dims");
140        }
141
142        qparams.scales.strides = axes.iter().map(|i| qparams.scales.strides[*i]).collect();
143        qparams.scales.shape = qparams.scales.shape.clone().permute(axes).unwrap();
144
145        tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::block(&block_size)))
146    }
147
148    if let DType::QFloat(scheme) = &mut tensor.dtype
149        && let QuantStore::PackedU32(packed_dim) = &mut scheme.store
150    {
151        let rank = tensor.shape.len();
152        let new_pos = axes
153            .iter()
154            .position(|axis| *axis == rank - *packed_dim - 1)
155            .unwrap_or(0);
156        *packed_dim = rank - new_pos - 1;
157    }
158
159    tensor
160}
161
162/// Permute a tensor's dimensions from NCHW to NHWC, or the N-dimensional equivalent
163pub fn permute_nchw_to_nhwc<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
164    let rank = tensor.shape.num_dims();
165    let c_dim = 1;
166
167    let mut dims = vec![0];
168    dims.extend(2..rank);
169    dims.push(c_dim);
170
171    permute(tensor, &dims)
172}
173
174/// Permute a shape's dimensions from NCHW to NHWC, or the N-dimensional equivalent
175pub fn permute_nchw_to_nhwc_shape(shape: Shape) -> Shape {
176    let rank = shape.num_dims();
177    let c_dim = 1;
178
179    let mut dims = vec![0];
180    dims.extend(2..rank);
181    dims.push(c_dim);
182
183    shape.permute(&dims).expect("Shape permute should succeed")
184}
185
186/// Permute a tensor's dimensions from NHWC to NCHW, or the N-dimensional equivalent
187pub fn permute_nhwc_to_nchw<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
188    let rank = tensor.shape.num_dims();
189    let c_dim = rank - 1;
190
191    let mut dims = vec![0];
192    dims.push(c_dim);
193    dims.extend(1..c_dim);
194
195    permute(tensor, &dims)
196}
197
198/// Permute a shape's dimensions from NHWC to NCHW, or the N-dimensional equivalent
199pub fn permute_nhwc_to_nchw_shape(shape: Shape) -> Shape {
200    let rank = shape.num_dims();
201    let c_dim = rank - 1;
202
203    let mut dims = vec![0];
204    dims.push(c_dim);
205    dims.extend(1..c_dim);
206
207    shape.permute(&dims).expect("Shape permute should succeed")
208}
209
210pub(crate) fn expand<R: CubeRuntime>(tensor: CubeTensor<R>, target_shape: Shape) -> CubeTensor<R> {
211    let ndims_in = tensor.shape.num_dims();
212    let ndims_out = target_shape.num_dims();
213
214    // Initialize new strides with zeros
215    let mut new_strides = vec![0usize; ndims_out];
216
217    // Calculate the difference in dimensions
218    let dim_diff = ndims_out.saturating_sub(ndims_in);
219
220    // Compare dimensions from the end, setting strides for matching dimensions or broadcasted ones
221    let mut tensor_dim_iter = tensor.shape.iter().rev();
222    for i in (0..ndims_out).rev() {
223        if i >= dim_diff {
224            if let Some(&tensor_dim) = tensor_dim_iter.next() {
225                if tensor_dim == target_shape[i] || tensor_dim == 1 {
226                    // Copy stride for non-broadcast dimensions or set to 0 for broadcast ones
227                    new_strides[i] = if tensor_dim == target_shape[i] {
228                        tensor.strides[i - dim_diff]
229                    } else {
230                        0
231                    };
232                } else {
233                    // Error handling: Dimension mismatch for broadcasting
234                    panic!(
235                        "Dimension mismatch: cannot broadcast dimension {tensor_dim} of tensor to target shape"
236                    );
237                }
238            } else {
239                // If the input tensor has fewer dimensions, treat missing dimensions as 1
240                // and set stride to 0 (broadcasting)
241                new_strides[i] = 0;
242            }
243        } else {
244            // For extra dimensions in the target shape, set stride to 0 (broadcasting)
245            new_strides[i] = 0;
246        }
247    }
248
249    // Extra check to ensure block scales must be properly handled once they're added
250    if tensor.qparams.is_some() {
251        match tensor.scheme().level {
252            QuantLevel::Tensor => {}
253            QuantLevel::Block(_) => todo!(),
254        }
255    }
256
257    CubeTensor {
258        client: tensor.client,
259        device: tensor.device,
260        shape: target_shape,
261        strides: new_strides,
262        handle: tensor.handle,
263        dtype: tensor.dtype,
264        qparams: tensor.qparams,
265    }
266}
267
268/// Reshape a jit tensor to a new shape
269pub fn reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
270    let analysis = reshape_action(&tensor.shape.dims, &tensor.strides, &shape.dims);
271
272    match analysis {
273        ReshapeAction::UpdateStrides { strides } => {
274            tensor.shape = shape;
275            tensor.strides = strides;
276            return tensor;
277        }
278        ReshapeAction::NoChange => return tensor,
279        ReshapeAction::Recompute => (),
280    }
281
282    let tensor = kernel::into_contiguous(tensor);
283
284    let mut out = CubeTensor::new_contiguous(
285        tensor.client,
286        tensor.device,
287        shape,
288        tensor.handle,
289        tensor.dtype,
290    );
291    out.qparams = tensor.qparams;
292    out
293}
294
/// Reshape a quantized tensor to a new logical shape.
///
/// Two layouts are analysed independently with `reshape_action`:
/// - the stored values, whose last dimension is `shape[rank - 1].div_ceil(num_quants)`
///   (several quantized elements per stored value);
/// - the scales, whose target shape is `params_shape(&shape, scheme.level)`.
///
/// If both analyses succeed without recomputation, only metadata (shape/strides) is
/// updated. Otherwise the tensor is made contiguous and both the values and the scales get
/// contiguous strides.
///
/// Note that `tensor.strides` is always set from the *values* layout (`shape_values`),
/// not from the logical `shape`.
///
/// # Panics
///
/// Panics if `shape` has rank 0 (`rank - 1` underflows) or if the tensor has no
/// quantization handles / qparams.
pub fn q_reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
    let scheme = *tensor.scheme();

    // Shape of the stored values: the innermost dim is divided by the number of
    // quantized elements packed per value, rounding up.
    let shape_values = {
        let rank = shape.num_dims();
        let mut shape = shape.clone();
        shape[rank - 1] = shape[rank - 1].div_ceil(scheme.num_quants());
        shape
    };
    let shape_scales = params_shape(&shape, scheme.level);
    let (values, scales) = tensor.quantized_handles().unwrap();

    let analysis_values = reshape_action(&values.shape.dims, &values.strides, &shape_values.dims);
    let analysis_scales = reshape_action(&scales.shape.dims, &scales.strides, &shape_scales.dims);

    match (analysis_values, analysis_scales) {
        // Both layouts can be reinterpreted in place.
        (
            ReshapeAction::UpdateStrides { strides },
            ReshapeAction::UpdateStrides {
                strides: scales_strides,
            },
        ) => {
            let qparams = tensor.qparams.as_mut().unwrap();

            tensor.shape = shape;
            tensor.strides = strides;

            qparams.scales.shape = shape_scales;
            qparams.scales.strides = scales_strides;
        }
        // Only the values layout changes; scales stay as they are.
        (ReshapeAction::UpdateStrides { strides }, ReshapeAction::NoChange) => {
            tensor.shape = shape;
            tensor.strides = strides;
        }
        // Only the scales layout changes.
        // NOTE(review): `tensor.shape` is not reassigned here; this presumably relies on an
        // unchanged values layout implying an unchanged logical shape — confirm.
        (
            ReshapeAction::NoChange,
            ReshapeAction::UpdateStrides {
                strides: scales_strides,
            },
        ) => {
            let qparams = tensor.qparams.as_mut().unwrap();

            qparams.scales.shape = shape_scales;
            qparams.scales.strides = scales_strides;
        }
        (ReshapeAction::NoChange, ReshapeAction::NoChange) => {}
        // At least one layout needs recomputation: make the tensor contiguous and assign
        // contiguous strides to both the values and the scales.
        _ => {
            tensor = kernel::into_contiguous(tensor);
            tensor.shape = shape;
            tensor.strides = contiguous_strides(&shape_values.dims);

            let qparams = tensor.qparams.as_mut().unwrap();

            qparams.scales.strides = contiguous_strides(&shape_scales.dims);
            qparams.scales.shape = shape_scales;
        }
    }

    tensor
}
356
357pub(crate) fn max_line_size<R: CubeRuntime>(tensor: &CubeTensor<R>) -> LineSize {
358    tensor_line_size_parallel(
359        tensor
360            .client
361            .io_optimized_line_sizes_unchecked(tensor.dtype.size()),
362        &tensor.shape,
363        &tensor.strides,
364        tensor.shape.len() - 1,
365    )
366}
367
368pub(crate) fn max_line_size_many<R: CubeRuntime>(
369    tensors: &[&CubeTensor<R>],
370    axis: usize,
371) -> LineSize {
372    let vec = tensors
373        .iter()
374        .map(|tensor| {
375            tensor_line_size_parallel(
376                tensor
377                    .client
378                    .io_optimized_line_sizes_unchecked(tensor.dtype.size()),
379                &tensor.shape,
380                &tensor.strides,
381                axis,
382            )
383        })
384        .min();
385
386    vec.unwrap_or(0)
387}
388
389/// Unfold windows along a dimension.
390///
391/// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
392/// where windows are advanced by `step` at each index.
393///
394/// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
395///
396/// The new view will have the unfolded dimension replaced by two dimensions;
397/// one in the position of the original dimension, with size equal to the number of windows,
398/// and one appended to the right-most position, with size equal to `size`.
399///
400/// # Arguments
401///
402/// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
403/// * `dim` - the dimension to unfold.
404/// * `size` - the size of each unfolded window.
405/// * `step` - the step between each window.
406///
407/// # Returns
408///
409/// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
410pub fn unfold<R: CubeRuntime>(
411    tensor: CubeTensor<R>,
412    dim: usize,
413    size: usize,
414    step: usize,
415) -> CubeTensor<R> {
416    let shape = calculate_unfold_shape(tensor.shape, dim, size, step);
417
418    let d_stride = tensor.strides[dim];
419    let mut strides = tensor.strides.clone();
420    strides[dim] = step * d_stride;
421    strides.push(d_stride);
422
423    CubeTensor {
424        shape,
425        strides,
426        ..tensor
427    }
428}