//! Base tensor primitive for the CubeCL backend (`burn_cubecl/tensor/base.rs`).

1use crate::CubeRuntime;
2use crate::element::CubeElement;
3use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
4use burn_common::tensor::is_contiguous;
5use burn_tensor::quantization::QTensorPrimitive;
6use burn_tensor::{DType, Shape, TensorMetadata};
7use cubecl::client::ComputeClient;
8use cubecl::frontend::Numeric;
9use cubecl::prelude::{TensorHandleRef, *};
10use cubecl::server::Handle;
11use cubecl::std::tensor::TensorHandle;
12use std::marker::PhantomData;
13
/// The basic tensor primitive struct.
///
/// `#[derive(new)]` generates a `CubeTensor::new` constructor taking the
/// fields in declaration order.
#[derive(new)]
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client for the [runtime](CubeRuntime).
    pub client: ComputeClient<R::Server, R::Channel>,
    /// The buffer where the data are stored.
    pub handle: Handle,
    /// The shape of the tensor.
    pub shape: Shape,
    /// The device of the tensor.
    pub device: R::Device,
    /// The strides of the tensor, expressed in number of elements
    /// (see how they are computed in `new_contiguous`).
    pub strides: Vec<usize>,
    /// The datatype of the tensor.
    pub dtype: DType,
}
30
31impl<R: CubeRuntime, E: CubeElement> From<CubeTensor<R>> for TensorHandle<R, E> {
32    fn from(val: CubeTensor<R>) -> Self {
33        TensorHandle::new(val.handle, val.shape.dims.to_vec(), val.strides.to_vec())
34    }
35}
36
37impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
38    #[cfg(feature = "autotune-checks")]
39    fn check_equivalence(&self, other: Self) {
40        use burn_tensor::Tolerance;
41
42        use crate::ops::into_data_sync;
43
44        match self.dtype {
45            DType::F64 => {
46                let expected = into_data_sync::<R, f64>(self.clone());
47                let actual = into_data_sync::<R, f64>(other);
48                expected.assert_approx_eq::<f64>(&actual, Tolerance::permissive());
49            }
50            DType::F32 | DType::Flex32 => {
51                let expected = into_data_sync::<R, f32>(self.clone());
52                let actual = into_data_sync::<R, f32>(other);
53                expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
54            }
55            DType::F16 => {
56                let expected = into_data_sync::<R, half::f16>(self.clone());
57                let actual = into_data_sync::<R, half::f16>(other);
58                expected.assert_approx_eq::<half::f16>(&actual, Tolerance::permissive());
59            }
60            DType::BF16 => {
61                let expected = into_data_sync::<R, half::bf16>(self.clone());
62                let actual = into_data_sync::<R, half::bf16>(other);
63                expected.assert_approx_eq::<half::bf16>(&actual, Tolerance::permissive());
64            }
65            DType::I64 => {
66                let expected = into_data_sync::<R, i64>(self.clone());
67                let actual = into_data_sync::<R, i64>(other);
68                expected.assert_eq(&actual, true);
69            }
70            DType::I32 => {
71                let expected = into_data_sync::<R, i32>(self.clone());
72                let actual = into_data_sync::<R, i32>(other);
73                expected.assert_eq(&actual, true);
74            }
75            DType::I16 => {
76                let expected = into_data_sync::<R, i16>(self.clone());
77                let actual = into_data_sync::<R, i16>(other);
78                expected.assert_eq(&actual, true);
79            }
80            DType::I8 => {
81                let expected = into_data_sync::<R, i8>(self.clone());
82                let actual = into_data_sync::<R, i8>(other);
83                expected.assert_eq(&actual, true);
84            }
85            DType::U64 => {
86                let expected = into_data_sync::<R, u64>(self.clone());
87                let actual = into_data_sync::<R, u64>(other);
88                expected.assert_eq(&actual, true);
89            }
90            DType::U32 => {
91                let expected = into_data_sync::<R, u32>(self.clone());
92                let actual = into_data_sync::<R, u32>(other);
93                expected.assert_eq(&actual, true);
94            }
95            DType::U16 => {
96                let expected = into_data_sync::<R, u16>(self.clone());
97                let actual = into_data_sync::<R, u16>(other);
98                expected.assert_eq(&actual, true);
99            }
100            DType::U8 => {
101                let expected = into_data_sync::<R, u8>(self.clone());
102                let actual = into_data_sync::<R, u8>(other);
103                expected.assert_eq(&actual, true);
104            }
105            DType::Bool => (),
106            DType::QFloat(..) => (),
107        }
108    }
109}
110
111impl<R> core::fmt::Debug for CubeTensor<R>
112where
113    R: CubeRuntime,
114{
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        f.write_fmt(format_args!(
117            "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
118            self.shape,
119            self.device,
120            self.strides,
121            self.dtype.name(),
122            R::name(&self.client),
123        ))
124    }
125}
126
// Implemented manually because `#[derive(Clone)]` would add an `R: Clone`
// bound on the runtime type parameter, which is not needed: only the fields
// themselves must be cloneable.
impl<R> Clone for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            // NOTE(review): assumed to be a cheap handle clone rather than a
            // copy of the underlying buffer — confirm against cubecl `Handle`.
            handle: self.handle.clone(),
            shape: self.shape.clone(),
            device: self.device.clone(),
            strides: self.strides.clone(),
            // `DType` is copied by value (no `clone()` call compiles, so it is `Copy`).
            dtype: self.dtype,
        }
    }
}
142
impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
    /// The element datatype of the tensor.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// The shape of the tensor (returned by clone; `Shape` owns its dims).
    fn shape(&self) -> Shape {
        self.shape.clone()
    }
}
152
153impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
154    fn scheme(&self) -> &burn_tensor::quantization::QuantScheme {
155        if let DType::QFloat(scheme) = &self.dtype {
156            scheme
157        } else {
158            panic!(
159                "Quantization scheme is not valid for dtype {:?}",
160                self.dtype,
161            )
162        }
163    }
164}
165
/// Macro to execute a kernel/operation for a given element type.
///
/// Each arm matches on a runtime `DType`, binds the corresponding concrete
/// Rust element type to the `$element` type alias, and expands `$op` with
/// that alias in scope.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_dtype {
    // Float-only variant: dispatches exclusively on floating-point dtypes.
    (float($dtype:expr), $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }

            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};

    // Binary float variant: checks both operands share the same dtype before
    // delegating to the unary float variant above.
    (float($lhs_dtype:expr, $rhs_dtype:expr), $element:ident, $op:expr) => {{
        // NOTE: might be better for floating point binary operations to return a Result instead?
        if $lhs_dtype != $rhs_dtype {
            panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                $lhs_dtype, $rhs_dtype
            );
        }
        execute_with_dtype!(float($lhs_dtype), $element, $op)
    }};
    // General variant: dispatches on all supported numeric dtypes.
    ($dtype:expr, $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            burn_tensor::DType::U64 => {
                type $element = u64;
                $op
            }
            burn_tensor::DType::U32 => {
                type $element = u32;
                $op
            }
            burn_tensor::DType::U16 => {
                type $element = u16;
                $op
            }
            burn_tensor::DType::U8 => {
                type $element = u8;
                $op
            }
            burn_tensor::DType::I64 => {
                type $element = i64;
                $op
            }
            burn_tensor::DType::I32 => {
                type $element = i32;
                $op
            }
            burn_tensor::DType::I16 => {
                type $element = i16;
                $op
            }
            burn_tensor::DType::I8 => {
                type $element = i8;
                $op
            }
            // NOTE: bool and qfloat dtypes are actually represented as u32/u8
            // burn_tensor::DType::Bool => {
            //     type $element = u32/u8;
            //     $op
            // }
            burn_tensor::DType::QFloat(_) => {
                type $element = u32;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};
}
277
impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Create a new tensor with a contiguous memory layout.
    ///
    /// Strides are computed row-major (C-contiguous): the last axis has
    /// stride 1 and each preceding axis' stride is the running product of
    /// all later dimension sizes. Strides are in number of elements.
    pub fn new_contiguous(
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

        // Walk the dims from the innermost axis outwards, accumulating the
        // product of dimension sizes as the stride of each axis.
        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
        }
    }

    /// Change the context of the current tensor and return the newly transferred tensor.
    ///
    /// The data travels through the host: a read from the current client
    /// followed by an upload to the target client. Shape, strides and dtype
    /// are preserved as-is.
    pub fn to_client(
        &self,
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
    ) -> Self {
        // NOTE(review): `read_one` appears to be a synchronous (blocking)
        // readback — confirm against the cubecl client API.
        let bytes = self.client.read_one(self.handle.clone().binding());
        let handle = client.create(&bytes);

        Self {
            client,
            handle,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            device,
            dtype: self.dtype,
        }
    }

    /// Return the reference to a tensor handle.
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    /// Size in bytes of one element, as stored in the buffer.
    fn elem_size(&self) -> usize {
        if let DType::QFloat(_) = self.dtype {
            // Encoded as u32
            core::mem::size_of::<u32>()
        } else {
            self.dtype.size()
        }
    }

    /// Return the reference to a tensor argument.
    pub fn as_tensor_arg<'a, E: CubeElement>(&'a self, line_size: u8) -> TensorArg<'a, R> {
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        // SAFETY(review): `from_raw_parts::<E>` reinterprets the buffer as
        // elements of type `E`; callers are assumed to pick `E` consistent
        // with `self.dtype` (e.g. via `execute_with_dtype!`) — confirm.
        unsafe {
            TensorArg::from_raw_parts::<E>(handle.handle, handle.strides, handle.shape, line_size)
        }
    }

    /// Return the reference to an array argument.
    pub fn as_array_arg<E: CubeElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
        // SAFETY(review): the element count is derived from the buffer byte
        // size divided by `size_of::<E>()`; assumes `E` matches the stored
        // element type — confirm at call sites.
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

    // Returns true when `self`'s buffer may be reused in place as the output
    // of a broadcasted binary operation with `rhs`: the buffer must be
    // exclusively owned and contiguous, and no lhs axis may be smaller than
    // the corresponding rhs axis (which would broadcast lhs to a larger
    // output). Assumes both tensors have the same rank.
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape.dims[i];
            let shape_rhs = rhs.shape.dims[i];

            // Output tensor will be different from the mutable tensor.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Copy the current tensor.
    ///
    /// Launches an identity unary kernel, producing a new tensor with the
    /// same values.
    pub fn copy(&self) -> Self {
        // Identity operation: the unary kernel writes each input line back
        // unchanged into a fresh output buffer.
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options<N: Numeric> = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();

        // Dispatch on the runtime dtype to pick the concrete element type.
        execute_with_dtype!(
            tensor.dtype,
            E,
            launch_unary_numeric::<R, E, Copy, _>(tensor, |_| ())
        )
    }

    /// Check if the tensor is safe to mutate.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Assert that both tensors are on the same device.
    ///
    /// # Panics
    /// Panics when the two devices differ.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Check if the current tensor is contiguous.
    ///
    /// A tensor is contiguous if its elements are stored in memory with
    /// strides in non-increasing order, and the stride at position `k`
    /// equals the product of the shapes at all positions greater than `k`.
    /// Axes with a shape of 1 are ignored.
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Check if the current tensor has a contiguous backing buffer (no overlap and no empty memory
    /// regions within the shape).
    ///
    /// NOTE(review): this uses `self.dtype.size()`, while `elem_size()` maps
    /// `QFloat` to `size_of::<u32>()` — confirm whether quantized tensors
    /// need the same treatment here.
    pub fn is_contiguous_buffer(&self) -> bool {
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}
450
#[cfg(test)]
mod tests {
    use super::*;

    // Axes of size 1 are ignored, so [3, 1] with strides [1, 1] counts as contiguous.
    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    // Plain row-major 2D layout.
    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    // Column-major (transposed) strides are not contiguous.
    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    // A slice leaves gaps: stride products no longer match the shape.
    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    // 4D row-major layout with exact stride products.
    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    // Swapped leading strides break the non-increasing requirement.
    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    /// Based on a bug encountered in interpolate_1d
    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
}