burn_cubecl/kernel/contiguous.rs

use burn_backend::{DType, QTensorPrimitive};
use cubecl::quant::scheme::{QuantStore, QuantValue};
use cubecl::server::AllocationKind;

use crate::{CubeRuntime, ops::empty_qtensor, tensor::CubeTensor};

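/// Makes `tensor` contiguous in memory, returning it unchanged when its layout
/// is already contiguous. Quantized tensors (those carrying `qparams`) are
/// routed through `into_contiguous_quantized` so that both the values and the
/// scales are relocated.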
pub fn into_contiguous<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
    if tensor.is_contiguous() {
        return tensor;
    }

    if tensor.qparams.is_some() {
        return into_contiguous_quantized(tensor, AllocationKind::Contiguous);
    }

    let output = cubecl::std::tensor::into_contiguous_ref(
        &tensor.client,
        &tensor.as_handle_ref(),
        tensor.dtype.into(),
    )
    .expect("Kernel to never fail");

    CubeTensor::new(
        tensor.client,
        output.handle,
        output.shape.into(),
        tensor.device,
        output.strides,
        tensor.dtype,
    )
}

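/// Makes `tensor` contiguous using the runtime's preferred (possibly pitched)
/// allocation, returning it unchanged when the runtime can already read the
/// current shape/strides directly. Quantized tensors are routed through
/// `into_contiguous_quantized` with an optimized allocation.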
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(level = "trace", skip(tensor))
)]
pub fn into_contiguous_aligned<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
    if R::can_read_tensor(&tensor.shape, &tensor.strides) {
        return tensor;
    }

    if tensor.qparams.is_some() {
        return into_contiguous_quantized(tensor, AllocationKind::Optimized);
    }

    let output = cubecl::std::tensor::into_contiguous_pitched_ref(
        &tensor.client,
        &tensor.as_handle_ref(),
        tensor.dtype.into(),
    )
    .expect("Kernel to never fail");

    CubeTensor::new(
        tensor.client,
        output.handle,
        output.shape.into(),
        tensor.device,
        output.strides,
        tensor.dtype,
    )
}

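/// Makes a quantized tensor contiguous by allocating a fresh quantized output of
/// the requested `AllocationKind` and copying the value and scale handles into
/// it. Packed stores are repacked along their packing dimension; other stores
/// are copied as-is.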
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(level = "trace", skip(tensor))
)]
fn into_contiguous_quantized<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    kind: AllocationKind,
) -> CubeTensor<R> {
    let scheme = tensor.scheme();
    let output = empty_qtensor(tensor.shape.clone(), *tensor.scheme(), &tensor.device, kind);
    let (values, scales) = tensor.quantized_handles().unwrap();
    let (out_values, out_scales) = output.quantized_handles().unwrap();

    match scheme.store {
        // Quantized values packed into u32 words: repack along `packed_dim`
        // while copying into the contiguous output.
        QuantStore::PackedU32(packed_dim) => {
            cubecl::std::tensor::into_contiguous_packed_ref(
                &values.client,
                &values.as_handle_ref(),
                &out_values.as_handle_ref(),
                packed_dim,
                &tensor.shape,
                scheme.num_quants(),
                DType::U32.into(),
            )
            .expect("Kernel to never fail");
        }
        // Natively packed E2M1 (4-bit) values: repack the same way, but treat
        // the storage as bytes.
        QuantStore::PackedNative(packed_dim) if scheme.value == QuantValue::E2M1 => {
            cubecl::std::tensor::into_contiguous_packed_ref(
                &values.client,
                &values.as_handle_ref(),
                &out_values.as_handle_ref(),
                packed_dim,
                &tensor.shape,
                scheme.num_quants(),
                DType::U8.into(),
            )
            .expect("Kernel to never fail");
        }
        // Unpacked stores: a plain strided copy into the contiguous output.
        _ => {
            cubecl::std::tensor::copy_into(
                &values.client,
                &values.as_handle_ref(),
                &out_values.as_handle_ref(),
                values.dtype.into(),
            )
            .expect("Kernel to never fail");
        }
    }

    // The scales are copied separately into the output's scales handle.
    cubecl::std::tensor::copy_into(
        &scales.client,
        &scales.as_handle_ref(),
        &out_scales.as_handle_ref(),
        scales.dtype.into(),
    )
    .expect("Kernel to never fail");

    output
}