burn_cubecl/kernel/
contiguous.rs1use burn_backend::{DType, QTensorPrimitive, TensorMetadata};
2use cubecl::quant::scheme::{QuantStore, QuantValue};
3use cubecl::server::MemoryLayoutStrategy;
4
5use crate::{CubeRuntime, ops::empty_qtensor, tensor::CubeTensor};
6
7pub fn into_contiguous<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
9 if tensor.is_contiguous() {
10 return tensor;
11 }
12
13 if tensor.qparams.is_some() {
14 return into_contiguous_quantized(tensor, MemoryLayoutStrategy::Contiguous);
15 }
16
17 let (client, device, dtype) = (tensor.client.clone(), tensor.device.clone(), tensor.dtype);
18
19 let output = cubecl::std::tensor::into_contiguous(&client, tensor.binding(), dtype.into());
20
21 CubeTensor::new(
22 client.clone(),
23 output.handle,
24 *output.metadata,
25 device,
26 dtype,
27 )
28}
29
30#[cfg_attr(
33 feature = "tracing",
34 tracing::instrument(level = "trace", skip(tensor))
35)]
36pub fn into_contiguous_aligned<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
37 if R::can_read_tensor(tensor.meta.shape(), tensor.meta.strides()) {
38 return tensor;
39 }
40
41 if tensor.qparams.is_some() {
42 return into_contiguous_quantized(tensor, MemoryLayoutStrategy::Optimized);
43 }
44
45 let (client, device, dtype) = (tensor.client.clone(), tensor.device.clone(), tensor.dtype);
46
47 let output =
48 cubecl::std::tensor::into_contiguous_pitched(&client, tensor.binding(), dtype.into());
49
50 CubeTensor::new(
51 client.clone(),
52 output.handle,
53 *output.metadata,
54 device,
55 dtype,
56 )
57}
58
59#[cfg_attr(
60 feature = "tracing",
61 tracing::instrument(level = "trace", skip(tensor))
62)]
63fn into_contiguous_quantized<R: CubeRuntime>(
64 tensor: CubeTensor<R>,
65 strategy: MemoryLayoutStrategy,
66) -> CubeTensor<R> {
67 let scheme = tensor.scheme();
68 let output = empty_qtensor(tensor.shape(), *tensor.scheme(), &tensor.device, strategy);
69 let (values, scales) = tensor.quantized_handles().unwrap();
70 let (out_values, out_scales) = output.quantized_handles().unwrap();
71
72 let (client, dtype_scales, dtype_value) = (scales.client.clone(), scales.dtype, values.dtype);
73
74 match scheme.store {
75 QuantStore::PackedU32(packed_dim) => {
76 cubecl::std::tensor::into_contiguous_packed_ref(
77 &client,
78 values.binding(),
79 out_values.binding(),
80 packed_dim,
81 tensor.meta.shape(),
82 scheme.num_quants(),
83 DType::U32.into(),
84 );
85 }
86 QuantStore::PackedNative(packed_dim) if scheme.value == QuantValue::E2M1 => {
89 cubecl::std::tensor::into_contiguous_packed_ref(
90 &client,
91 values.binding(),
92 out_values.binding(),
93 packed_dim,
94 tensor.meta.shape(),
95 scheme.num_quants(),
96 DType::U8.into(),
97 );
98 }
99 _ => {
100 cubecl::std::tensor::copy_into(
101 &client,
102 values.binding(),
103 out_values.binding(),
104 dtype_value.into(),
105 );
106 }
107 }
108
109 cubecl::std::tensor::copy_into(
110 &client,
111 scales.binding(),
112 out_scales.binding(),
113 dtype_scales.into(),
114 );
115
116 output
117}