burn_cubecl/kernel/
contiguous.rs

1use burn_backend::{DType, QTensorPrimitive};
2use cubecl::quant::scheme::{QuantStore, QuantValue};
3use cubecl::server::AllocationKind;
4
5use crate::{CubeRuntime, ops::empty_qtensor, tensor::CubeTensor};
6
7/// Make a jit tensor contiguous.
8pub fn into_contiguous<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
9    if tensor.is_contiguous() {
10        return tensor;
11    }
12
13    if tensor.qparams.is_some() {
14        return into_contiguous_quantized(tensor, AllocationKind::Contiguous);
15    }
16
17    let output = cubecl::std::tensor::into_contiguous_ref(
18        &tensor.client,
19        &tensor.as_handle_ref(),
20        tensor.dtype.into(),
21    )
22    .expect("Kernel to never fail");
23
24    CubeTensor::new(
25        tensor.client,
26        output.handle,
27        output.shape.into(),
28        tensor.device,
29        output.strides,
30        tensor.dtype,
31    )
32}
33
34/// Make a jit tensor contiguous with an aligned last stride. Tensor is considered already contiguous
35/// if runtime can read it as is. This is equivalent in practice.
36#[cfg_attr(
37    feature = "tracing",
38    tracing::instrument(level = "trace", skip(tensor))
39)]
40pub fn into_contiguous_aligned<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
41    if R::can_read_tensor(&tensor.shape, &tensor.strides) {
42        return tensor;
43    }
44
45    if tensor.qparams.is_some() {
46        return into_contiguous_quantized(tensor, AllocationKind::Optimized);
47    }
48
49    let output = cubecl::std::tensor::into_contiguous_pitched_ref(
50        &tensor.client,
51        &tensor.as_handle_ref(),
52        tensor.dtype.into(),
53    )
54    .expect("Kernel to never fail");
55
56    CubeTensor::new(
57        tensor.client,
58        output.handle,
59        output.shape.into(),
60        tensor.device,
61        output.strides,
62        tensor.dtype,
63    )
64}
65
66#[cfg_attr(
67    feature = "tracing",
68    tracing::instrument(level = "trace", skip(tensor))
69)]
70fn into_contiguous_quantized<R: CubeRuntime>(
71    tensor: CubeTensor<R>,
72    kind: AllocationKind,
73) -> CubeTensor<R> {
74    let scheme = tensor.scheme();
75    let output = empty_qtensor(tensor.shape.clone(), *tensor.scheme(), &tensor.device, kind);
76    let (values, scales) = tensor.quantized_handles().unwrap();
77    let (out_values, out_scales) = output.quantized_handles().unwrap();
78
79    match scheme.store {
80        QuantStore::PackedU32(packed_dim) => {
81            cubecl::std::tensor::into_contiguous_packed_ref(
82                &values.client,
83                &values.as_handle_ref(),
84                &out_values.as_handle_ref(),
85                packed_dim,
86                &tensor.shape,
87                scheme.num_quants(),
88                DType::U32.into(),
89            )
90            .expect("Kernel to never fail");
91        }
92        // e2m1 is special because it has a native packed representation, `e2m1x2`.
93        // It's internally stored as `u8` with a packing factor of 2.
94        QuantStore::PackedNative(packed_dim) if scheme.value == QuantValue::E2M1 => {
95            cubecl::std::tensor::into_contiguous_packed_ref(
96                &values.client,
97                &values.as_handle_ref(),
98                &out_values.as_handle_ref(),
99                packed_dim,
100                &tensor.shape,
101                scheme.num_quants(),
102                DType::U8.into(),
103            )
104            .expect("Kernel to never fail");
105        }
106        _ => {
107            cubecl::std::tensor::copy_into(
108                &values.client,
109                &values.as_handle_ref(),
110                &out_values.as_handle_ref(),
111                values.dtype.into(),
112            )
113            .expect("Kernel to never fail");
114        }
115    }
116
117    cubecl::std::tensor::copy_into(
118        &scales.client,
119        &scales.as_handle_ref(),
120        &out_scales.as_handle_ref(),
121        scales.dtype.into(),
122    )
123    .expect("Kernel to never fail");
124
125    output
126}