burn_cubecl/tensor/
base.rs

1use crate::CubeRuntime;
2use crate::element::CubeElement;
3use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
4use burn_tensor::quantization::QTensorPrimitive;
5use burn_tensor::{DType, Shape, TensorMetadata};
6use cubecl::client::ComputeClient;
7use cubecl::frontend::Numeric;
8use cubecl::linalg::tensor::TensorHandle;
9use cubecl::prelude::{TensorHandleRef, *};
10use cubecl::server::Handle;
11use std::marker::PhantomData;
12
/// The basic tensor primitive struct.
///
/// Holds the compute client, the backing buffer and the layout metadata
/// (shape, strides, dtype) describing how the buffer is interpreted.
#[derive(new)]
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client for the [runtime](CubeRuntime).
    pub client: ComputeClient<R::Server, R::Channel>,
    /// The buffer where the data are stored.
    ///
    /// The handle may be shared with other tensors; check [`Handle::can_mut`]
    /// (see `can_mut` below) before mutating it in place.
    pub handle: Handle,
    /// The shape of the tensor.
    pub shape: Shape,
    /// The device of the tensor.
    pub device: R::Device,
    /// The strides of the tensor, expressed in number of elements per
    /// dimension (not bytes), as computed by `new_contiguous`.
    pub strides: Vec<usize>,
    /// The datatype of the tensor.
    pub dtype: DType,
}
29
30impl<R: CubeRuntime, E: CubeElement> From<CubeTensor<R>> for TensorHandle<R, E> {
31    fn from(val: CubeTensor<R>) -> Self {
32        TensorHandle::new(val.handle, val.shape.dims.to_vec(), val.strides.to_vec())
33    }
34}
35
impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    /// Compare `self` against `other` element-wise, panicking on mismatch.
    ///
    /// Only compiled with the `export_tests` feature; used by autotune to
    /// verify that candidate kernels produce equivalent outputs. Float dtypes
    /// are compared with a relative/absolute tolerance (f16 gets a looser
    /// absolute tolerance of 4e-3), integer dtypes must match exactly, and
    /// bool/quantized tensors are not checked.
    #[cfg(feature = "export_tests")]
    fn check_equivalence(&self, other: Self) {
        use burn_tensor::Tolerance;

        use crate::ops::into_data_sync;

        // Both tensors are read back synchronously and compared on the host
        // at the concrete element type selected by `self.dtype`.
        match self.dtype {
            DType::F64 => {
                let expected = into_data_sync::<R, f64>(self.clone());
                let actual = into_data_sync::<R, f64>(other);
                expected.assert_approx_eq::<f64>(&actual, Tolerance::rel_abs(1e-2, 1e-3));
            }
            DType::F32 | DType::Flex32 => {
                let expected = into_data_sync::<R, f32>(self.clone());
                let actual = into_data_sync::<R, f32>(other);
                expected.assert_approx_eq::<f32>(&actual, Tolerance::rel_abs(1e-2, 1e-3));
            }
            DType::F16 => {
                let expected = into_data_sync::<R, half::f16>(self.clone());
                let actual = into_data_sync::<R, half::f16>(other);
                expected.assert_approx_eq::<half::f16>(&actual, Tolerance::rel_abs(1e-2, 4e-3));
            }
            DType::BF16 => {
                let expected = into_data_sync::<R, half::bf16>(self.clone());
                let actual = into_data_sync::<R, half::bf16>(other);
                expected.assert_approx_eq::<half::bf16>(&actual, Tolerance::rel_abs(1e-2, 1e-3));
            }
            // Integer dtypes: exact equality is required.
            DType::I64 => {
                let expected = into_data_sync::<R, i64>(self.clone());
                let actual = into_data_sync::<R, i64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I32 => {
                let expected = into_data_sync::<R, i32>(self.clone());
                let actual = into_data_sync::<R, i32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I16 => {
                let expected = into_data_sync::<R, i16>(self.clone());
                let actual = into_data_sync::<R, i16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I8 => {
                let expected = into_data_sync::<R, i8>(self.clone());
                let actual = into_data_sync::<R, i8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U64 => {
                let expected = into_data_sync::<R, u64>(self.clone());
                let actual = into_data_sync::<R, u64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U32 => {
                let expected = into_data_sync::<R, u32>(self.clone());
                let actual = into_data_sync::<R, u32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U16 => {
                let expected = into_data_sync::<R, u16>(self.clone());
                let actual = into_data_sync::<R, u16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U8 => {
                let expected = into_data_sync::<R, u8>(self.clone());
                let actual = into_data_sync::<R, u8>(other);
                expected.assert_eq(&actual, true);
            }
            // Bool and quantized outputs are intentionally not compared.
            DType::Bool => (),
            DType::QFloat(..) => (),
        }
    }
}
109
110impl<R> core::fmt::Debug for CubeTensor<R>
111where
112    R: CubeRuntime,
113{
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.write_fmt(format_args!(
116            "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
117            self.shape,
118            self.device,
119            self.strides,
120            self.dtype.name(),
121            R::name(&self.client),
122        ))
123    }
124}
125
126impl<R> Clone for CubeTensor<R>
127where
128    R: CubeRuntime,
129{
130    fn clone(&self) -> Self {
131        Self {
132            client: self.client.clone(),
133            handle: self.handle.clone(),
134            shape: self.shape.clone(),
135            device: self.device.clone(),
136            strides: self.strides.clone(),
137            dtype: self.dtype,
138        }
139    }
140}
141
142impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
143    fn dtype(&self) -> DType {
144        self.dtype
145    }
146
147    fn shape(&self) -> Shape {
148        self.shape.clone()
149    }
150}
151
152impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
153    fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme {
154        if let DType::QFloat(scheme) = &self.dtype {
155            scheme
156        } else {
157            panic!(
158                "Quantization scheme is not valid for dtype {:?}",
159                self.dtype,
160            )
161        }
162    }
163}
164
/// Macro to execute a kernel/operation for a given element type.
///
/// Each arm binds the `$element` identifier to the concrete Rust element type
/// (via a local `type` alias) before expanding `$op`, so `$op` can be generic
/// over `$element`. The `float(...)` variants only dispatch on floating point
/// dtypes; the plain variant also covers integer and quantized dtypes.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_dtype {
    // Dispatch on a single floating point dtype.
    (float($dtype:expr), $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }

            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};

    // Binary float case: both operands must share the same dtype.
    (float($lhs_dtype:expr, $rhs_dtype:expr), $element:ident, $op:expr) => {{
        // NOTE: might be better for floating point binary operations to return a Result instead?
        if $lhs_dtype != $rhs_dtype {
            panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                $lhs_dtype, $rhs_dtype
            );
        }
        execute_with_dtype!(float($lhs_dtype), $element, $op)
    }};
    // General case: float, integer and quantized dtypes.
    ($dtype:expr, $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            burn_tensor::DType::U64 => {
                type $element = u64;
                $op
            }
            burn_tensor::DType::U32 => {
                type $element = u32;
                $op
            }
            burn_tensor::DType::U16 => {
                type $element = u16;
                $op
            }
            burn_tensor::DType::U8 => {
                type $element = u8;
                $op
            }
            burn_tensor::DType::I64 => {
                type $element = i64;
                $op
            }
            burn_tensor::DType::I32 => {
                type $element = i32;
                $op
            }
            burn_tensor::DType::I16 => {
                type $element = i16;
                $op
            }
            burn_tensor::DType::I8 => {
                type $element = i8;
                $op
            }
            // NOTE: bool and qfloat dtypes are actually represented as u32/u8
            // burn_tensor::DType::Bool => {
            //     type $element = u32/u8;
            //     $op
            // }
            burn_tensor::DType::QFloat(_) => {
                type $element = u32;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};
}
276
impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Create a new tensor with a contiguous memory layout.
    ///
    /// Strides are computed for a packed row-major layout: iterating the
    /// dimensions from innermost to outermost, each stride is the running
    /// product of the inner dimension sizes (the innermost stride is 1).
    pub fn new_contiguous(
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
        }
    }

    /// Change the context of the current tensor and return the newly transferred tensor.
    ///
    /// The data is read back synchronously from the current client and
    /// uploaded to the target `client`; panics if the read cannot complete
    /// synchronously.
    pub fn to_client(
        &self,
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
    ) -> Self {
        let bytes = burn_common::reader::try_read_sync(
            self.client.read_one_async(self.handle.clone().binding()),
        )
        .expect("Can only change client synchronously");
        let handle = client.create(&bytes);

        Self {
            client,
            handle,
            shape: self.shape.clone(),
            // Strides are copied as-is: the raw bytes preserve the layout.
            strides: self.strides.clone(),
            device,
            dtype: self.dtype,
        }
    }

    /// Return the reference to a tensor handle.
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    /// Size in bytes of a single stored element.
    ///
    /// Quantized (`QFloat`) tensors are stored packed in `u32` words, so
    /// their element size is that of `u32` rather than `DType::size()`.
    fn elem_size(&self) -> usize {
        if let DType::QFloat(_) = self.dtype {
            // Encoded as u32
            core::mem::size_of::<u32>()
        } else {
            self.dtype.size()
        }
    }

    /// Return the reference to a tensor argument.
    ///
    /// `vectorisation` is the line size used by the kernel; `E` must match
    /// `self.dtype` — this is the caller's responsibility.
    pub fn as_tensor_arg<'a, E: CubeElement>(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        // SAFETY: handle, strides and shape all come from this tensor and
        // stay borrowed for 'a; the element type contract is upheld by the
        // caller (see doc above).
        unsafe {
            TensorArg::from_raw_parts::<E>(
                handle.handle,
                handle.strides,
                handle.shape,
                vectorisation,
            )
        }
    }

    /// Return the reference to an array argument.
    ///
    /// The length is derived from the buffer size in bytes divided by
    /// `size_of::<E>()`; `E` must match `self.dtype` (caller's
    /// responsibility).
    pub fn as_array_arg<E: CubeElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
        // SAFETY: the handle is borrowed from this tensor and the computed
        // length covers exactly the backing buffer, assuming `E` has the
        // element size the buffer was created with.
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

    /// Check whether this tensor's buffer can be reused as the output of a
    /// broadcasted binary operation with `rhs`.
    ///
    /// Requires an exclusively owned, fully packed buffer, and that no
    /// dimension of `self` is smaller than the matching dimension of `rhs`
    /// (otherwise the broadcasted output would not fit in this buffer).
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape.dims[i];
            let shape_rhs = rhs.shape.dims[i];

            // Output tensor will be different from the mutable tensor.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Copy the current tensor.
    pub fn copy(&self) -> Self {
        // Identity unary op: launching it writes the same values into a
        // fresh output buffer, which materializes the copy on device.
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options<N: Numeric> = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();

        // Dispatch on the runtime dtype so the kernel is launched with the
        // matching element type.
        execute_with_dtype!(
            tensor.dtype,
            E,
            launch_unary_numeric::<R, E, Copy, _>(tensor, |_| ())
        )
    }

    /// Check if the tensor is safe to mutate.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Assert that both tensors are on the same device.
    ///
    /// Panics with both device descriptions when they differ.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Check if the current tensor is contiguous.
    ///
    /// A tensor is contiguous when its elements are laid out in packed
    /// row-major order: the stride at position k equals the product of the
    /// shapes at all positions greater than k. Axes with a shape of 1 are
    /// ignored.
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Check if the current tensor has a contiguous backing buffer (no overlap and no empty memory
    /// regions within the shape).
    // NOTE(review): this uses `self.dtype.size()` while `elem_size()` special
    // cases `QFloat` as u32 — confirm whether quantized tensors should use
    // `elem_size()` here as well.
    pub fn is_contiguous_buffer(&self) -> bool {
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}
457
/// Check if a layout described by `shape` and `strides` is contiguous.
///
/// A layout is contiguous when, ignoring every axis of size 1, the stride of
/// each remaining axis equals the product of the sizes of all inner axes —
/// i.e. standard packed row-major order, with the innermost non-unit axis
/// having stride 1. An empty shape is contiguous.
///
/// This fixes two defects of the previous implementation:
/// - the innermost non-unit axis stride was never validated (every check was
///   gated on `i > 0`), so e.g. `shape = [32], strides = [4]` was wrongly
///   reported as contiguous;
/// - `enumerate()` counted skipped size-1 axes, so a trailing axis of size 1
///   caused valid layouts such as `shape = [4, 1], strides = [1, 1]` to be
///   rejected (the first real axis was compared against the initial
///   `prev_stride = 1`).
pub(crate) fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    // The stride each axis must have in a packed row-major layout,
    // accumulated from the innermost axis outward.
    let mut expected_stride = 1;

    for (stride, size) in strides.iter().zip(shape).rev() {
        // Axes of size 1 contribute nothing to the layout and are ignored.
        if *size == 1 {
            continue;
        }

        if *stride != expected_stride {
            return false;
        }

        expected_stride *= size;
    }

    true
}
487
#[cfg(test)]
mod tests {
    use crate::tensor::base::is_contiguous;

    // Packed row-major strides are contiguous.
    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    // Transposed (column-major) strides are not contiguous.
    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    // A sliced view leaves gaps: the outer stride (32) does not match the
    // product of the inner shapes (64).
    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    // Every stride is the exact product of the inner dimension sizes.
    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    // Outer axes permuted: strides no longer decrease with the shape order.
    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }
}