burn_cubecl/tensor/base.rs

use crate::CubeRuntime;
use crate::element::CubeElement;
use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
use burn_common::tensor::is_contiguous;
use burn_tensor::quantization::QTensorPrimitive;
use burn_tensor::{DType, Shape, TensorMetadata};
use cubecl::client::ComputeClient;
use cubecl::frontend::Numeric;
use cubecl::prelude::{TensorHandleRef, *};
use cubecl::server::Handle;
use cubecl::std::tensor::TensorHandle;
use std::marker::PhantomData;

use super::QParams;
/// The basic tensor primitive struct.
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client for the [runtime](CubeRuntime).
    pub client: ComputeClient<R::Server>,
    /// The buffer where the data are stored.
    pub handle: Handle,
    /// The shape of the tensor.
    pub shape: Shape,
    /// The device of the tensor.
    pub device: R::Device,
    /// The strides of the tensor.
    pub strides: Vec<usize>,
    /// The datatype of the tensor.
    pub dtype: DType,
    /// Runtime quantization parameters, if applicable.
    pub qparams: Option<QParams>,
}

impl<R: CubeRuntime, E: CubeElement> From<CubeTensor<R>> for TensorHandle<R, E> {
    fn from(val: CubeTensor<R>) -> Self {
        TensorHandle::new(val.handle, val.shape.to_vec(), val.strides.to_vec())
    }
}

impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self) {
        use burn_tensor::Tolerance;

        use crate::ops::into_data_sync;

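        // Float dtypes are compared approximately with a permissive tolerance;
        // integer dtypes must match exactly.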
        match self.dtype {
            DType::F64 => {
                let expected = into_data_sync::<R, f64>(self.clone());
                let actual = into_data_sync::<R, f64>(other);
                expected.assert_approx_eq::<f64>(&actual, Tolerance::permissive());
            }
            DType::F32 | DType::Flex32 => {
                let expected = into_data_sync::<R, f32>(self.clone());
                let actual = into_data_sync::<R, f32>(other);
                expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
            }
            DType::F16 => {
                let expected = into_data_sync::<R, half::f16>(self.clone());
                let actual = into_data_sync::<R, half::f16>(other);
                expected.assert_approx_eq::<half::f16>(&actual, Tolerance::permissive());
            }
            DType::BF16 => {
                let expected = into_data_sync::<R, half::bf16>(self.clone());
                let actual = into_data_sync::<R, half::bf16>(other);
                expected.assert_approx_eq::<half::bf16>(&actual, Tolerance::permissive());
            }
            DType::I64 => {
                let expected = into_data_sync::<R, i64>(self.clone());
                let actual = into_data_sync::<R, i64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I32 => {
                let expected = into_data_sync::<R, i32>(self.clone());
                let actual = into_data_sync::<R, i32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I16 => {
                let expected = into_data_sync::<R, i16>(self.clone());
                let actual = into_data_sync::<R, i16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I8 => {
                let expected = into_data_sync::<R, i8>(self.clone());
                let actual = into_data_sync::<R, i8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U64 => {
                let expected = into_data_sync::<R, u64>(self.clone());
                let actual = into_data_sync::<R, u64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U32 => {
                let expected = into_data_sync::<R, u32>(self.clone());
                let actual = into_data_sync::<R, u32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U16 => {
                let expected = into_data_sync::<R, u16>(self.clone());
                let actual = into_data_sync::<R, u16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U8 => {
                let expected = into_data_sync::<R, u8>(self.clone());
                let actual = into_data_sync::<R, u8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::Bool => (),
            DType::QFloat(..) => (),
        }
    }
}

impl<R> core::fmt::Debug for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
            self.shape,
            self.device,
            self.strides,
            self.dtype.name(),
            R::name(&self.client),
        ))
    }
}

impl<R> Clone for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            handle: self.handle.clone(),
            shape: self.shape.clone(),
            device: self.device.clone(),
            strides: self.strides.clone(),
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }
}

impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
    fn dtype(&self) -> DType {
        self.dtype
    }

    fn shape(&self) -> Shape {
        self.shape.clone()
    }

    fn rank(&self) -> usize {
        self.shape.num_dims()
    }
}

impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
    fn scheme(&self) -> &burn_tensor::quantization::QuantScheme {
        if let DType::QFloat(scheme) = &self.dtype {
            scheme
        } else {
            panic!(
                "Quantization scheme is not valid for dtype {:?}",
                self.dtype,
            )
        }
    }
}

/// Macro to execute a kernel/operation for a given element type.
///
/// # Panics
///
/// Since there is no automatic type cast at this time, binary operations whose operands have
/// different floating-point precisions will panic with a data type mismatch.
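///
/// # Example
///
/// A usage sketch, mirroring `CubeTensor::copy` below: the matching arm binds the alias `E`
/// to a concrete element type, which the expression can then use as a generic argument.
///
/// ```ignore
/// execute_with_dtype!(
///     tensor.dtype,
///     E,
///     launch_unary_numeric::<R, E, Copy, _>(tensor, |_| ())
/// )
/// ```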
#[macro_export]
macro_rules! execute_with_dtype {
    (float($dtype:expr), $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};
    (float($lhs_dtype:expr, $rhs_dtype:expr), $element:ident, $op:expr) => {{
        // NOTE: might be better for floating point binary operations to return a Result instead?
        if $lhs_dtype != $rhs_dtype {
            panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                $lhs_dtype, $rhs_dtype
            );
        }
        execute_with_dtype!(float($lhs_dtype), $element, $op)
    }};
    (int($dtype:expr), $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::U64 => {
                type $element = u64;
                $op
            }
            burn_tensor::DType::U32 => {
                type $element = u32;
                $op
            }
            burn_tensor::DType::U16 => {
                type $element = u16;
                $op
            }
            burn_tensor::DType::U8 => {
                type $element = u8;
                $op
            }
            burn_tensor::DType::I64 => {
                type $element = i64;
                $op
            }
            burn_tensor::DType::I32 => {
                type $element = i32;
                $op
            }
            burn_tensor::DType::I16 => {
                type $element = i16;
                $op
            }
            burn_tensor::DType::I8 => {
                type $element = i8;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};
    ($dtype:expr, $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            burn_tensor::DType::U64 => {
                type $element = u64;
                $op
            }
            burn_tensor::DType::U32 => {
                type $element = u32;
                $op
            }
            burn_tensor::DType::U16 => {
                type $element = u16;
                $op
            }
            burn_tensor::DType::U8 => {
                type $element = u8;
                $op
            }
            burn_tensor::DType::I64 => {
                type $element = i64;
                $op
            }
            burn_tensor::DType::I32 => {
                type $element = i32;
                $op
            }
            burn_tensor::DType::I16 => {
                type $element = i16;
                $op
            }
            burn_tensor::DType::I8 => {
                type $element = i8;
                $op
            }
            // NOTE: bool and qfloat dtypes are actually represented as u32/u8
            // burn_tensor::DType::Bool => {
            //     type $element = u32/u8;
            //     $op
            // }
            burn_tensor::DType::QFloat(_) => {
                type $element = u32;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};
}

impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Create a new standard tensor.
    pub fn new(
        client: ComputeClient<R::Server>,
        handle: Handle,
        shape: Shape,
        device: R::Device,
        strides: Vec<usize>,
        dtype: DType,
    ) -> Self {
        CubeTensor {
            client,
            handle,
            shape,
            device,
            strides,
            dtype,
            qparams: None,
        }
    }

    /// Create a new tensor with a contiguous memory layout.
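    ///
    /// Strides are derived from the shape in row-major order (last dimension changes fastest);
    /// for instance, a shape of `[2, 3, 4]` yields strides of `[12, 4, 1]`.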
    pub fn new_contiguous(
        client: ComputeClient<R::Server>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
            qparams: None,
        }
    }

    /// Change the context of the current tensor and return the newly transferred tensor.
    pub fn to_client(&self, client: ComputeClient<R::Server>, device: R::Device) -> Self {
        let desc = self
            .handle
            .copy_descriptor(&self.shape.dims, &self.strides, self.elem_size());
        let alloc = self.client.to_client_tensor(desc, &client);

        Self {
            client,
            handle: alloc.handle,
            shape: self.shape.clone(),
            device,
            strides: alloc.strides,
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }

    /// Return the reference to a tensor handle.
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    /// Returns the element size of this tensor.
    pub fn elem_size(&self) -> usize {
        if let DType::QFloat(_) = self.dtype {
            // Encoded as u32
            core::mem::size_of::<u32>()
        } else {
            self.dtype.size()
        }
    }

    /// Return the reference to a tensor argument.
    pub fn as_tensor_arg<'a, E: CubeElement>(&'a self, line_size: u8) -> TensorArg<'a, R> {
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        unsafe {
            TensorArg::from_raw_parts::<E>(handle.handle, handle.strides, handle.shape, line_size)
        }
    }

    /// Return the reference to an array argument.
    pub fn as_array_arg<E: CubeElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

    /// Check if the tensor can be safely mutated and reused as the output of a broadcasted
    /// binary operation with `rhs` (i.e., no dimension of `rhs` is larger than the
    /// corresponding dimension of `self`).
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape[i];
            let shape_rhs = rhs.shape[i];

            // Output tensor will be different from the mutable tensor.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Copy the current tensor.
    pub fn copy(&self) -> Self {
        // Identity unary op: the copy is produced by launching a kernel that
        // returns its input unchanged.
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options<N: Numeric> = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();

        execute_with_dtype!(
            tensor.dtype,
            E,
            launch_unary_numeric::<R, E, Copy, _>(tensor, |_| ())
        )
    }

    /// Check if the tensor is safe to mutate.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Assert that both tensors are on the same device.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Check if the current tensor is contiguous.
    ///
    /// A tensor is contiguous if its elements are stored in memory such that the strides
    /// are in non-increasing order and the stride at position `k` equals the product of
    /// the shapes at all positions greater than `k`. Axes with a shape of 1 are ignored.
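    ///
    /// For example, a `[32, 32]` tensor with strides `[32, 1]` is contiguous, while a
    /// permuted view with strides `[1, 32]` is not (see the tests below).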
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Check if the current tensor has a contiguous backing buffer (no overlap and no empty memory
    /// regions within the shape).
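    ///
    /// This compares the logical size against the handle size; for example, a `[2, 3]` tensor of
    /// `f32` is backed by a contiguous buffer only if the handle spans exactly `6 * 4 = 24` bytes.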
    pub fn is_contiguous_buffer(&self) -> bool {
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    /// Based on a bug encountered in interpolate_1d
    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
}