use crate::CubeRuntime;
use crate::element::CubeElement;
use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
use burn_common::tensor::is_contiguous;
use burn_tensor::quantization::QTensorPrimitive;
use burn_tensor::{DType, Shape, TensorMetadata};
use cubecl::client::ComputeClient;
use cubecl::frontend::Numeric;
use cubecl::prelude::{TensorHandleRef, *};
use cubecl::server::Handle;
use cubecl::std::tensor::TensorHandle;
use std::marker::PhantomData;

/// The basic tensor primitive struct.
#[derive(new)]
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client for the runtime.
    pub client: ComputeClient<R::Server, R::Channel>,
    /// The buffer where the tensor data is stored.
    pub handle: Handle,
    /// The shape of the tensor.
    pub shape: Shape,
    /// The device of the tensor.
    pub device: R::Device,
    /// The strides of the tensor.
    pub strides: Vec<usize>,
    /// The data type of the tensor.
    pub dtype: DType,
}

impl<R: CubeRuntime, E: CubeElement> From<CubeTensor<R>> for TensorHandle<R, E> {
    fn from(val: CubeTensor<R>) -> Self {
        TensorHandle::new(val.handle, val.shape.dims.to_vec(), val.strides.to_vec())
    }
}

impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self) {
        use burn_tensor::Tolerance;

        use crate::ops::into_data_sync;

        match self.dtype {
            DType::F64 => {
                let expected = into_data_sync::<R, f64>(self.clone());
                let actual = into_data_sync::<R, f64>(other);
                expected.assert_approx_eq::<f64>(&actual, Tolerance::permissive());
            }
            DType::F32 | DType::Flex32 => {
                let expected = into_data_sync::<R, f32>(self.clone());
                let actual = into_data_sync::<R, f32>(other);
                expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
            }
            DType::F16 => {
                let expected = into_data_sync::<R, half::f16>(self.clone());
                let actual = into_data_sync::<R, half::f16>(other);
                expected.assert_approx_eq::<half::f16>(&actual, Tolerance::permissive());
            }
            DType::BF16 => {
                let expected = into_data_sync::<R, half::bf16>(self.clone());
                let actual = into_data_sync::<R, half::bf16>(other);
                expected.assert_approx_eq::<half::bf16>(&actual, Tolerance::permissive());
            }
            DType::I64 => {
                let expected = into_data_sync::<R, i64>(self.clone());
                let actual = into_data_sync::<R, i64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I32 => {
                let expected = into_data_sync::<R, i32>(self.clone());
                let actual = into_data_sync::<R, i32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I16 => {
                let expected = into_data_sync::<R, i16>(self.clone());
                let actual = into_data_sync::<R, i16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I8 => {
                let expected = into_data_sync::<R, i8>(self.clone());
                let actual = into_data_sync::<R, i8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U64 => {
                let expected = into_data_sync::<R, u64>(self.clone());
                let actual = into_data_sync::<R, u64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U32 => {
                let expected = into_data_sync::<R, u32>(self.clone());
                let actual = into_data_sync::<R, u32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U16 => {
                let expected = into_data_sync::<R, u16>(self.clone());
                let actual = into_data_sync::<R, u16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U8 => {
                let expected = into_data_sync::<R, u8>(self.clone());
                let actual = into_data_sync::<R, u8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::Bool => (),
            DType::QFloat(..) => (),
        }
    }
}

impl<R> core::fmt::Debug for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
            self.shape,
            self.device,
            self.strides,
            self.dtype.name(),
            R::name(&self.client),
        ))
    }
}

impl<R> Clone for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            handle: self.handle.clone(),
            shape: self.shape.clone(),
            device: self.device.clone(),
            strides: self.strides.clone(),
            dtype: self.dtype,
        }
    }
}

impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
    fn dtype(&self) -> DType {
        self.dtype
    }

    fn shape(&self) -> Shape {
        self.shape.clone()
    }
}

impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
    fn scheme(&self) -> &burn_tensor::quantization::QuantScheme {
        if let DType::QFloat(scheme) = &self.dtype {
            scheme
        } else {
            panic!(
                "Quantization scheme is not valid for dtype {:?}",
                self.dtype,
            )
        }
    }
}

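/// Dispatch an operation over a runtime `DType`, binding the `$element` identifier to the
/// matching concrete element type.
///
/// The `float(..)` forms restrict dispatch to floating-point types; the two-dtype float form
/// panics on a mismatch since no automatic cast is performed.
///
/// A minimal usage sketch (hedged: `launch_my_kernel` is a hypothetical generic function, not
/// part of this crate):
///
/// ```ignore
/// execute_with_dtype!(
///     tensor.dtype,
///     E,
///     launch_my_kernel::<R, E>(tensor)
/// )
/// ```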
#[macro_export]
macro_rules! execute_with_dtype {
    (float($dtype:expr), $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};

    (float($lhs_dtype:expr, $rhs_dtype:expr), $element:ident, $op:expr) => {{
        // No automatic cast is performed: both operands must share the same float dtype.
        if $lhs_dtype != $rhs_dtype {
            panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                $lhs_dtype, $rhs_dtype
            );
        }
        execute_with_dtype!(float($lhs_dtype), $element, $op)
    }};
    ($dtype:expr, $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            burn_tensor::DType::U64 => {
                type $element = u64;
                $op
            }
            burn_tensor::DType::U32 => {
                type $element = u32;
                $op
            }
            burn_tensor::DType::U16 => {
                type $element = u16;
                $op
            }
            burn_tensor::DType::U8 => {
                type $element = u8;
                $op
            }
            burn_tensor::DType::I64 => {
                type $element = i64;
                $op
            }
            burn_tensor::DType::I32 => {
                type $element = i32;
                $op
            }
            burn_tensor::DType::I16 => {
                type $element = i16;
                $op
            }
            burn_tensor::DType::I8 => {
                type $element = i8;
                $op
            }
            // Quantized tensors are stored as packed `u32` values, so dispatch them as `u32`.
            burn_tensor::DType::QFloat(_) => {
                type $element = u32;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};
}

impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
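    /// Create a new tensor with a contiguous (row-major) memory layout, computing the strides
    /// from the shape.
    ///
    /// A minimal sketch of the resulting strides, assuming `client`, `device`, and `handle` were
    /// obtained from the runtime and the handle holds enough `f32` values for the shape:
    ///
    /// ```ignore
    /// let tensor = CubeTensor::<R>::new_contiguous(
    ///     client,
    ///     device,
    ///     Shape::new([2, 3, 4]),
    ///     handle,
    ///     DType::F32,
    /// );
    /// assert_eq!(tensor.strides, vec![12, 4, 1]);
    /// ```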
    pub fn new_contiguous(
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
        }
    }

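    /// Transfer the tensor to another compute client/device by reading back its bytes and
    /// uploading them to the target client.
    ///
    /// A hedged usage sketch, assuming `other_client` and `other_device` refer to another
    /// device of the same runtime:
    ///
    /// ```ignore
    /// let moved = tensor.to_client(other_client, other_device);
    /// assert_eq!(moved.shape(), tensor.shape());
    /// ```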
    pub fn to_client(
        &self,
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
    ) -> Self {
        let bytes = self.client.read_one(self.handle.clone().binding());
        let handle = client.create(&bytes);

        Self {
            client,
            handle,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            device,
            dtype: self.dtype,
        }
    }

    /// Return a non-owning handle reference to this tensor's buffer, strides and shape.
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    fn elem_size(&self) -> usize {
        if let DType::QFloat(_) = self.dtype {
            // Quantized values are stored as packed `u32` words.
            core::mem::size_of::<u32>()
        } else {
            self.dtype.size()
        }
    }

    /// Return the tensor as a kernel tensor argument with the given line (vectorization) size.
    pub fn as_tensor_arg<'a, E: CubeElement>(&'a self, line_size: u8) -> TensorArg<'a, R> {
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        unsafe {
            TensorArg::from_raw_parts::<E>(handle.handle, handle.strides, handle.shape, line_size)
        }
    }

    /// Return the tensor as a flat array kernel argument with the given vectorization factor.
    pub fn as_array_arg<E: CubeElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

    /// Check if the tensor can be written in place as the output of a broadcasted binary
    /// operation with `rhs`.
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape.dims[i];
            let shape_rhs = rhs.shape.dims[i];

            // If the lhs dimension is broadcast, the output is larger than the lhs buffer and
            // cannot be written in place.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Copy the tensor into a new contiguous buffer by launching an identity unary kernel.
    pub fn copy(&self) -> Self {
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options<N: Numeric> = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();

        execute_with_dtype!(
            tensor.dtype,
            E,
            launch_unary_numeric::<R, E, Copy, _>(tensor, |_| ())
        )
    }

    /// Check whether the underlying buffer can be safely mutated (i.e. it is not shared).
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Panic if `other` is not on the same device as `self`.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

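    /// Check whether the tensor's memory layout is contiguous, i.e. its strides describe a dense
    /// row-major traversal of the shape (see the `is_contiguous` tests at the bottom of this file).
    ///
    /// A minimal sketch, assuming `client`, `device`, and `handle` come from the runtime:
    ///
    /// ```ignore
    /// let t = CubeTensor::<R>::new_contiguous(client, device, Shape::new([32, 32]), handle, DType::F32);
    /// assert!(t.is_contiguous());
    /// ```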
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Check that the buffer exactly fits the tensor data, i.e. it contains no offset or padding.
    pub fn is_contiguous_buffer(&self) -> bool {
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
}