1use crate::CubeRuntime;
2use crate::element::CubeElement;
3use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
4use burn_common::tensor::is_contiguous;
5use burn_tensor::quantization::QTensorPrimitive;
6use burn_tensor::{DType, Shape, TensorMetadata};
7use cubecl::client::ComputeClient;
8use cubecl::frontend::Numeric;
9use cubecl::prelude::{TensorHandleRef, *};
10use cubecl::server::Handle;
11use cubecl::std::tensor::TensorHandle;
12use std::marker::PhantomData;
13
14use super::QParams;
15
/// Tensor primitive for the cube backends: a raw device buffer together with
/// the layout metadata (shape, strides, dtype) needed to interpret it.
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client for the runtime, used to launch kernels and move data.
    pub client: ComputeClient<R::Server>,
    /// Raw handle to the device buffer backing this tensor.
    pub handle: Handle,
    /// Logical shape of the tensor.
    pub shape: Shape,
    /// Device on which the buffer lives.
    pub device: R::Device,
    /// Stride of each dimension.
    pub strides: Vec<usize>,
    /// Element data type of the buffer.
    pub dtype: DType,
    /// Quantization parameters; `None` for non-quantized tensors.
    pub qparams: Option<QParams>,
}
33
34impl<R: CubeRuntime, E: CubeElement> From<CubeTensor<R>> for TensorHandle<R, E> {
35 fn from(val: CubeTensor<R>) -> Self {
36 TensorHandle::new(val.handle, val.shape.to_vec(), val.strides.to_vec())
37 }
38}
39
40impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
41 #[cfg(feature = "autotune-checks")]
42 fn check_equivalence(&self, other: Self) {
43 use burn_tensor::Tolerance;
44
45 use crate::ops::into_data_sync;
46
47 match self.dtype {
48 DType::F64 => {
49 let expected = into_data_sync::<R, f64>(self.clone());
50 let actual = into_data_sync::<R, f64>(other);
51 expected.assert_approx_eq::<f64>(&actual, Tolerance::permissive());
52 }
53 DType::F32 | DType::Flex32 => {
54 let expected = into_data_sync::<R, f32>(self.clone());
55 let actual = into_data_sync::<R, f32>(other);
56 expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
57 }
58 DType::F16 => {
59 let expected = into_data_sync::<R, half::f16>(self.clone());
60 let actual = into_data_sync::<R, half::f16>(other);
61 expected.assert_approx_eq::<half::f16>(&actual, Tolerance::permissive());
62 }
63 DType::BF16 => {
64 let expected = into_data_sync::<R, half::bf16>(self.clone());
65 let actual = into_data_sync::<R, half::bf16>(other);
66 expected.assert_approx_eq::<half::bf16>(&actual, Tolerance::permissive());
67 }
68 DType::I64 => {
69 let expected = into_data_sync::<R, i64>(self.clone());
70 let actual = into_data_sync::<R, i64>(other);
71 expected.assert_eq(&actual, true);
72 }
73 DType::I32 => {
74 let expected = into_data_sync::<R, i32>(self.clone());
75 let actual = into_data_sync::<R, i32>(other);
76 expected.assert_eq(&actual, true);
77 }
78 DType::I16 => {
79 let expected = into_data_sync::<R, i16>(self.clone());
80 let actual = into_data_sync::<R, i16>(other);
81 expected.assert_eq(&actual, true);
82 }
83 DType::I8 => {
84 let expected = into_data_sync::<R, i8>(self.clone());
85 let actual = into_data_sync::<R, i8>(other);
86 expected.assert_eq(&actual, true);
87 }
88 DType::U64 => {
89 let expected = into_data_sync::<R, u64>(self.clone());
90 let actual = into_data_sync::<R, u64>(other);
91 expected.assert_eq(&actual, true);
92 }
93 DType::U32 => {
94 let expected = into_data_sync::<R, u32>(self.clone());
95 let actual = into_data_sync::<R, u32>(other);
96 expected.assert_eq(&actual, true);
97 }
98 DType::U16 => {
99 let expected = into_data_sync::<R, u16>(self.clone());
100 let actual = into_data_sync::<R, u16>(other);
101 expected.assert_eq(&actual, true);
102 }
103 DType::U8 => {
104 let expected = into_data_sync::<R, u8>(self.clone());
105 let actual = into_data_sync::<R, u8>(other);
106 expected.assert_eq(&actual, true);
107 }
108 DType::Bool => (),
109 DType::QFloat(..) => (),
110 }
111 }
112}
113
114impl<R> core::fmt::Debug for CubeTensor<R>
115where
116 R: CubeRuntime,
117{
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.write_fmt(format_args!(
120 "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
121 self.shape,
122 self.device,
123 self.strides,
124 self.dtype.name(),
125 R::name(&self.client),
126 ))
127 }
128}
129
130impl<R> Clone for CubeTensor<R>
131where
132 R: CubeRuntime,
133{
134 fn clone(&self) -> Self {
135 Self {
136 client: self.client.clone(),
137 handle: self.handle.clone(),
138 shape: self.shape.clone(),
139 device: self.device.clone(),
140 strides: self.strides.clone(),
141 dtype: self.dtype,
142 qparams: self.qparams.clone(),
143 }
144 }
145}
146
impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
    /// Element data type of the tensor.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// Logical shape of the tensor (returned by clone, as `Shape` is owned).
    fn shape(&self) -> Shape {
        self.shape.clone()
    }

    /// Number of dimensions.
    fn rank(&self) -> usize {
        self.shape.num_dims()
    }
}
160
161impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
162 fn scheme(&self) -> &burn_tensor::quantization::QuantScheme {
163 if let DType::QFloat(scheme) = &self.dtype {
164 scheme
165 } else {
166 panic!(
167 "Quantization scheme is not valid for dtype {:?}",
168 self.dtype,
169 )
170 }
171 }
172}
173
/// Dispatches `$op` with the type alias `$element` bound to the concrete Rust
/// element type matching a runtime [`burn_tensor::DType`].
///
/// Variants:
/// - `float($dtype)`: float dtypes only (`F64`/`F32`/`Flex32`/`F16`/`BF16`).
/// - `float($lhs_dtype, $rhs_dtype)`: as above, after asserting both dtypes match.
/// - `int($dtype)`: integer dtypes only (signed and unsigned, 8 to 64 bits).
/// - `$dtype`: any numeric dtype; `QFloat` dispatches as `u32` (packed storage).
///
/// Unsupported dtypes hit `unimplemented!` at runtime.
#[macro_export]
macro_rules! execute_with_dtype {
    // Float-only dispatch.
    (float($dtype:expr), $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }

            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};

    // Float dispatch for binary ops: both operands must share a dtype.
    (float($lhs_dtype:expr, $rhs_dtype:expr), $element:ident, $op:expr) => {{
        if $lhs_dtype != $rhs_dtype {
            panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                $lhs_dtype, $rhs_dtype
            );
        }
        execute_with_dtype!(float($lhs_dtype), $element, $op)
    }};
    // Integer-only dispatch.
    (int($dtype:expr), $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::U64 => {
                type $element = u64;
                $op
            }
            burn_tensor::DType::U32 => {
                type $element = u32;
                $op
            }
            burn_tensor::DType::U16 => {
                type $element = u16;
                $op
            }
            burn_tensor::DType::U8 => {
                type $element = u8;
                $op
            }
            burn_tensor::DType::I64 => {
                type $element = i64;
                $op
            }
            burn_tensor::DType::I32 => {
                type $element = i32;
                $op
            }
            burn_tensor::DType::I16 => {
                type $element = i16;
                $op
            }
            burn_tensor::DType::I8 => {
                type $element = i8;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};
    // Generic numeric dispatch over every supported dtype.
    ($dtype:expr, $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            burn_tensor::DType::U64 => {
                type $element = u64;
                $op
            }
            burn_tensor::DType::U32 => {
                type $element = u32;
                $op
            }
            burn_tensor::DType::U16 => {
                type $element = u16;
                $op
            }
            burn_tensor::DType::U8 => {
                type $element = u8;
                $op
            }
            burn_tensor::DType::I64 => {
                type $element = i64;
                $op
            }
            burn_tensor::DType::I32 => {
                type $element = i32;
                $op
            }
            burn_tensor::DType::I16 => {
                type $element = i16;
                $op
            }
            burn_tensor::DType::I8 => {
                type $element = i8;
                $op
            }
            // Quantized tensors are stored as packed `u32` words
            // (matching `CubeTensor::elem_size`), so ops run on `u32`.
            burn_tensor::DType::QFloat(_) => {
                type $element = u32;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};
}
322
impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Creates a new tensor from its raw parts, with no quantization
    /// parameters attached.
    pub fn new(
        client: ComputeClient<R::Server>,
        handle: Handle,
        shape: Shape,
        device: R::Device,
        strides: Vec<usize>,
        dtype: DType,
    ) -> Self {
        CubeTensor {
            client,
            handle,
            shape,
            device,
            strides,
            dtype,
            qparams: None,
        }
    }

    /// Creates a new tensor whose strides are computed as contiguous
    /// (row-major) for the given `shape`.
    pub fn new_contiguous(
        client: ComputeClient<R::Server>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

        // Walk the dimensions from innermost to outermost so that each stride
        // is the product of all faster-varying dimension sizes.
        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
            qparams: None,
        }
    }

    /// Copies this tensor's buffer to another client/device pair, returning a
    /// new tensor that owns the destination allocation.
    pub fn to_client(&self, client: ComputeClient<R::Server>, device: R::Device) -> Self {
        let desc = self
            .handle
            .copy_descriptor(&self.shape.dims, &self.strides, self.elem_size());
        let alloc = self.client.to_client_tensor(desc, &client);

        Self {
            client,
            handle: alloc.handle,
            shape: self.shape.clone(),
            device,
            // The destination allocation decides its own layout.
            strides: alloc.strides,
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }

    /// Borrows this tensor as a cubecl [`TensorHandleRef`].
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    /// Size in bytes of one stored element. Quantized tensors store their
    /// values packed into `u32` words, so their element size is that of `u32`.
    pub fn elem_size(&self) -> usize {
        if let DType::QFloat(_) = self.dtype {
            core::mem::size_of::<u32>()
        } else {
            self.dtype.size()
        }
    }

    /// Builds a kernel tensor argument with the given line (vectorization)
    /// size.
    pub fn as_tensor_arg<'a, E: CubeElement>(&'a self, line_size: u8) -> TensorArg<'a, R> {
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        // SAFETY: handle, strides and shape all come from this tensor's own
        // handle ref, so they describe the same live buffer.
        unsafe {
            TensorArg::from_raw_parts::<E>(handle.handle, handle.strides, handle.shape, line_size)
        }
    }

    /// Builds a flat array kernel argument over the whole buffer, with the
    /// given vectorization factor. The length is derived from the buffer size
    /// in bytes divided by `size_of::<E>()`.
    pub fn as_array_arg<E: CubeElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
        // SAFETY: the handle and the computed element count describe this
        // tensor's own live buffer.
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

    /// Whether this tensor's buffer can be reused as the output of a
    /// broadcasted binary operation with `rhs`: the handle must be mutable,
    /// the buffer contiguous, and every lhs dimension at least as large as the
    /// matching rhs dimension (so rhs broadcasts into lhs, never the reverse).
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape[i];
            let shape_rhs = rhs.shape[i];

            // Output into lhs is only valid if rhs broadcasts into it.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Returns a new tensor holding a copy of this tensor's data, produced by
    /// launching an identity unary kernel.
    pub fn copy(&self) -> Self {
        // Identity element-wise op: the "copy" is performed by the unary
        // launch machinery writing each input line to the output.
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options<N: Numeric> = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();

        execute_with_dtype!(
            tensor.dtype,
            E,
            launch_unary_numeric::<R, E, Copy, _>(tensor, |_| ())
        )
    }

    /// Whether the underlying handle can be mutated in place.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Panics when `other` lives on a different device than `self`.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Whether the tensor's layout is contiguous (row-major, no gaps).
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Whether the buffer size exactly matches the logical element count, i.e.
    /// the tensor is not a view over a larger allocation.
    // NOTE(review): this uses `dtype.size()` rather than `elem_size()`; for
    // `QFloat` the two may differ — confirm which is intended here.
    pub fn is_contiguous_buffer(&self) -> bool {
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}
516
#[cfg(test)]
mod tests {
    use super::*;

    // Unit-sized dims may carry any stride without breaking contiguity.
    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    // Plain row-major layout.
    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    // Transposed (column-major) strides are not contiguous.
    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    // A slice with a gap between rows is not contiguous.
    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    // Swapped outer strides break contiguity even if the element count fits.
    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    // Unit dims don't save a layout whose non-unit strides have gaps.
    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
}