1use crate::CubeRuntime;
2use crate::element::CubeElement;
3use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
4use burn_backend::quantization::QuantScheme;
5use burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata};
6use burn_std::tensor::is_contiguous;
7use cubecl::client::ComputeClient;
8use cubecl::frontend::Numeric;
9use cubecl::prelude::{TensorHandleRef, *};
10use cubecl::server::Handle;
11use cubecl::std::tensor::TensorHandle;
12use std::marker::PhantomData;
13
14use super::QParams;
15
/// A tensor allocated on a CubeCL runtime: a device buffer handle plus the
/// layout metadata (shape, strides, dtype) needed to launch kernels on it.
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client used to submit work for this tensor's device.
    pub client: ComputeClient<R>,
    /// Handle to the underlying device buffer.
    pub handle: Handle,
    /// Logical shape of the tensor.
    pub shape: Shape,
    /// Device the tensor lives on.
    pub device: R::Device,
    /// Per-dimension strides, in elements (see `new_contiguous`).
    pub strides: Vec<usize>,
    /// Element data type.
    pub dtype: DType,
    /// Quantization parameters; `None` for non-quantized tensors.
    pub qparams: Option<QParams>,
}
33
34impl<R: CubeRuntime> From<CubeTensor<R>> for TensorHandle<R> {
35 fn from(val: CubeTensor<R>) -> Self {
36 TensorHandle::new(
37 val.handle,
38 val.shape.to_vec(),
39 val.strides.to_vec(),
40 val.dtype.into(),
41 )
42 }
43}
44
45impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
46 #[cfg(feature = "autotune-checks")]
47 fn check_equivalence(&self, other: Self) {
48 use crate::ops::into_data_sync;
49 use burn_backend::Tolerance;
50
51 let expected = into_data_sync::<R>(self.clone());
52 let actual = into_data_sync::<R>(other);
53 expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
54 }
55}
56
57impl<R> core::fmt::Debug for CubeTensor<R>
58where
59 R: CubeRuntime,
60{
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 f.write_fmt(format_args!(
63 "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
64 self.shape,
65 self.device,
66 self.strides,
67 self.dtype.name(),
68 R::name(&self.client),
69 ))
70 }
71}
72
73impl<R> Clone for CubeTensor<R>
74where
75 R: CubeRuntime,
76{
77 fn clone(&self) -> Self {
78 Self {
79 client: self.client.clone(),
80 handle: self.handle.clone(),
81 shape: self.shape.clone(),
82 device: self.device.clone(),
83 strides: self.strides.clone(),
84 dtype: self.dtype,
85 qparams: self.qparams.clone(),
86 }
87 }
88}
89
impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
    /// Element data type of the tensor.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// Logical shape, cloned since the trait returns it by value.
    fn shape(&self) -> Shape {
        self.shape.clone()
    }

    /// Number of dimensions.
    fn rank(&self) -> usize {
        self.shape.num_dims()
    }
}
103
104impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
105 fn scheme(&self) -> &QuantScheme {
106 if let DType::QFloat(scheme) = &self.dtype {
107 scheme
108 } else {
109 panic!(
110 "Quantization scheme is not valid for dtype {:?}",
111 self.dtype,
112 )
113 }
114 }
115}
116
impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Creates a tensor from its raw parts, with no quantization parameters.
    pub fn new(
        client: ComputeClient<R>,
        handle: Handle,
        shape: Shape,
        device: R::Device,
        strides: Vec<usize>,
        dtype: DType,
    ) -> Self {
        CubeTensor {
            client,
            handle,
            shape,
            device,
            strides,
            dtype,
            qparams: None,
        }
    }

    /// Creates a tensor with contiguous row-major strides derived from `shape`:
    /// the last dimension gets stride 1, and each earlier dimension's stride is
    /// the product of all dimensions after it.
    pub fn new_contiguous(
        client: ComputeClient<R>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

        // Walk dims from innermost to outermost, accumulating the running
        // element count as each dimension's stride.
        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
            qparams: None,
        }
    }

    /// Copies this tensor to another client/device, returning the new tensor.
    pub fn to_client(&self, client: ComputeClient<R>, device: R::Device) -> Self {
        let desc = self
            .handle
            .copy_descriptor(&self.shape.dims, &self.strides, self.elem_size());
        let alloc = self.client.to_client_tensor(desc, &client);

        Self {
            client,
            handle: alloc.handle,
            shape: self.shape.clone(),
            device,
            // The destination allocation chooses its own strides, which may
            // differ from `self.strides`.
            strides: alloc.strides,
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }

    /// Borrows this tensor as a handle reference suitable for kernel launches.
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    /// Size in bytes of one element of this tensor's dtype.
    pub fn elem_size(&self) -> usize {
        self.dtype.size()
    }

    /// Builds a `TensorArg` for a kernel launch with the given line
    /// (vectorization) size.
    pub fn as_tensor_arg<'a>(&'a self, line_size: LineSize) -> TensorArg<'a, R> {
        let size = self.dtype.size();
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        // SAFETY: handle, strides, and shape all come from this tensor, so
        // they consistently describe the same allocation.
        unsafe {
            TensorArg::from_raw_parts_and_size(
                handle.handle,
                handle.strides,
                handle.shape,
                line_size,
                size,
            )
        }
    }

    /// Views the underlying buffer as a flat array argument of element type `E`.
    ///
    /// The length comes from the buffer's byte size, not the logical shape, so
    /// it may exceed `shape.num_elements()` for an oversized buffer.
    pub fn as_array_arg<E: CubeElement>(&self, line_size: LineSize) -> ArrayArg<'_, R> {
        // SAFETY: the handle is a valid allocation owned by this tensor, and
        // the length is computed from its actual byte size.
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                line_size,
            )
        }
    }

    /// Returns the quantization scheme when the dtype is quantized, else `None`.
    pub fn try_scheme(&self) -> Option<&QuantScheme> {
        match &self.dtype {
            DType::QFloat(scheme) => Some(scheme),
            _ => None,
        }
    }

    /// Whether this tensor's buffer can serve in place as the output of a
    /// broadcasted binary op with `rhs`: it must be mutable, exactly sized
    /// (see `is_contiguous_buffer`), and at least as large as `rhs` in every
    /// dimension.
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        // NOTE(review): indexes `rhs.shape[i]` for every dim of `self`, so
        // this assumes both tensors share the same rank — confirm callers.
        for i in 0..ndims {
            let shape_lhs = self.shape[i];
            let shape_rhs = rhs.shape[i];

            // The in-place output cannot be smaller than the broadcast operand.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Copies the tensor data by launching an identity (pass-through) unary
    /// kernel over a metadata clone of `self`.
    pub fn copy(&self) -> Self {
        // Identity op: forwards each input line unchanged.
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options = ();
            type Unary<N: Numeric> = Self;
        }

        // The clone shares the buffer handle; `launch_unary_numeric` produces
        // the output tensor for the launch.
        let tensor = self.clone();
        launch_unary_numeric::<R, Copy, _>(tensor, |_| ())
    }

    /// Whether the underlying buffer handle allows mutation.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Panics when `other` lives on a different device than `self`.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Whether the strides describe a contiguous row-major layout for the
    /// shape (delegates to `burn_std::tensor::is_contiguous`).
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Whether the buffer's byte size equals exactly `num_elements * elem_size`,
    /// i.e. the buffer holds this tensor's data and nothing else (no view into
    /// a larger allocation).
    pub fn is_contiguous_buffer(&self) -> bool {
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}
315
#[cfg(test)]
mod tests {
    use super::*;

    // Unit-sized dims allow any stride value; [3, 1] with strides [1, 1]
    // still addresses elements contiguously.
    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    // Plain row-major 2D layout.
    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    // Transposed (column-major) strides are not contiguous.
    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    // Strides of a slice into a larger buffer: gaps break contiguity.
    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    // Standard contiguous 4D layout (NCHW-style products).
    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    // Same strides with the two outer dims swapped relative to the shape.
    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    // Unit dims everywhere except the last, but the innermost stride is 8,
    // so consecutive elements are not adjacent in memory.
    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
}
355}