//! burn_cubecl/kernel/contiguous.rs — kernels that copy tensors into
//! contiguous (or aligned/pitched) memory layouts.

use burn_backend::{DType, QTensorPrimitive, TensorMetadata};
use cubecl::quant::scheme::{QuantStore, QuantValue};
use cubecl::server::MemoryLayoutStrategy;

use crate::{CubeRuntime, ops::empty_qtensor, tensor::CubeTensor};
6
7/// Make a jit tensor contiguous.
8pub fn into_contiguous<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
9    if tensor.is_contiguous() {
10        return tensor;
11    }
12
13    if tensor.qparams.is_some() {
14        return into_contiguous_quantized(tensor, MemoryLayoutStrategy::Contiguous);
15    }
16
17    let (client, device, dtype) = (tensor.client.clone(), tensor.device.clone(), tensor.dtype);
18
19    let output = cubecl::std::tensor::into_contiguous(&client, tensor.binding(), dtype.into());
20
21    CubeTensor::new(
22        client.clone(),
23        output.handle,
24        *output.metadata,
25        device,
26        dtype,
27    )
28}
29
30/// Make a jit tensor contiguous with an aligned last stride. Tensor is considered already contiguous
31/// if runtime can read it as is. This is equivalent in practice.
32#[cfg_attr(
33    feature = "tracing",
34    tracing::instrument(level = "trace", skip(tensor))
35)]
36pub fn into_contiguous_aligned<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
37    if R::can_read_tensor(tensor.meta.shape(), tensor.meta.strides()) {
38        return tensor;
39    }
40
41    if tensor.qparams.is_some() {
42        return into_contiguous_quantized(tensor, MemoryLayoutStrategy::Optimized);
43    }
44
45    let (client, device, dtype) = (tensor.client.clone(), tensor.device.clone(), tensor.dtype);
46
47    let output =
48        cubecl::std::tensor::into_contiguous_pitched(&client, tensor.binding(), dtype.into());
49
50    CubeTensor::new(
51        client.clone(),
52        output.handle,
53        *output.metadata,
54        device,
55        dtype,
56    )
57}
58
59#[cfg_attr(
60    feature = "tracing",
61    tracing::instrument(level = "trace", skip(tensor))
62)]
63fn into_contiguous_quantized<R: CubeRuntime>(
64    tensor: CubeTensor<R>,
65    strategy: MemoryLayoutStrategy,
66) -> CubeTensor<R> {
67    let scheme = tensor.scheme();
68    let output = empty_qtensor(tensor.shape(), *tensor.scheme(), &tensor.device, strategy);
69    let (values, scales) = tensor.quantized_handles().unwrap();
70    let (out_values, out_scales) = output.quantized_handles().unwrap();
71
72    let (client, dtype_scales, dtype_value) = (scales.client.clone(), scales.dtype, values.dtype);
73
74    match scheme.store {
75        QuantStore::PackedU32(packed_dim) => {
76            cubecl::std::tensor::into_contiguous_packed_ref(
77                &client,
78                values.binding(),
79                out_values.binding(),
80                packed_dim,
81                tensor.meta.shape(),
82                scheme.num_quants(),
83                DType::U32.into(),
84            );
85        }
86        // e2m1 is special because it has a native packed representation, `e2m1x2`.
87        // It's internally stored as `u8` with a packing factor of 2.
88        QuantStore::PackedNative(packed_dim) if scheme.value == QuantValue::E2M1 => {
89            cubecl::std::tensor::into_contiguous_packed_ref(
90                &client,
91                values.binding(),
92                out_values.binding(),
93                packed_dim,
94                tensor.meta.shape(),
95                scheme.num_quants(),
96                DType::U8.into(),
97            );
98        }
99        _ => {
100            cubecl::std::tensor::copy_into(
101                &client,
102                values.binding(),
103                out_values.binding(),
104                dtype_value.into(),
105            );
106        }
107    }
108
109    cubecl::std::tensor::copy_into(
110        &client,
111        scales.binding(),
112        out_scales.binding(),
113        dtype_scales.into(),
114    );
115
116    output
117}