burn_cubecl/tensor/
base.rs

use crate::CubeRuntime;
use crate::element::CubeElement;
use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
use burn_common::tensor::is_contiguous;
use burn_tensor::quantization::QTensorPrimitive;
use burn_tensor::{DType, Shape, TensorMetadata};
use cubecl::client::ComputeClient;
use cubecl::frontend::Numeric;
use cubecl::prelude::{TensorHandleRef, *};
use cubecl::server::Handle;
use cubecl::std::tensor::TensorHandle;
use std::marker::PhantomData;

use super::QParams;

/// The basic tensor primitive struct.
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client for the [runtime](CubeRuntime).
    pub client: ComputeClient<R::Server>,
    /// The buffer where the data is stored.
    pub handle: Handle,
    /// The shape of the tensor.
    pub shape: Shape,
    /// The device of the tensor.
    pub device: R::Device,
    /// The strides of the tensor.
    pub strides: Vec<usize>,
    /// The datatype of the tensor.
    pub dtype: DType,
    /// Runtime quantization parameters, if applicable.
    pub qparams: Option<QParams>,
}

impl<R: CubeRuntime> From<CubeTensor<R>> for TensorHandle<R> {
    fn from(val: CubeTensor<R>) -> Self {
        TensorHandle::new(
            val.handle,
            val.shape.to_vec(),
            val.strides.to_vec(),
            val.dtype.into(),
        )
    }
}

impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self) {
        use burn_tensor::Tolerance;

        use crate::ops::into_data_sync;

        match self.dtype {
            DType::F64 => {
                let expected = into_data_sync::<R, f64>(self.clone());
                let actual = into_data_sync::<R, f64>(other);
                expected.assert_approx_eq::<f64>(&actual, Tolerance::permissive());
            }
            DType::F32 | DType::Flex32 => {
                let expected = into_data_sync::<R, f32>(self.clone());
                let actual = into_data_sync::<R, f32>(other);
                expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
            }
            DType::F16 => {
                let expected = into_data_sync::<R, half::f16>(self.clone());
                let actual = into_data_sync::<R, half::f16>(other);
                expected.assert_approx_eq::<half::f16>(&actual, Tolerance::permissive());
            }
            DType::BF16 => {
                let expected = into_data_sync::<R, half::bf16>(self.clone());
                let actual = into_data_sync::<R, half::bf16>(other);
                expected.assert_approx_eq::<half::bf16>(&actual, Tolerance::permissive());
            }
            DType::I64 => {
                let expected = into_data_sync::<R, i64>(self.clone());
                let actual = into_data_sync::<R, i64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I32 => {
                let expected = into_data_sync::<R, i32>(self.clone());
                let actual = into_data_sync::<R, i32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I16 => {
                let expected = into_data_sync::<R, i16>(self.clone());
                let actual = into_data_sync::<R, i16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I8 => {
                let expected = into_data_sync::<R, i8>(self.clone());
                let actual = into_data_sync::<R, i8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U64 => {
                let expected = into_data_sync::<R, u64>(self.clone());
                let actual = into_data_sync::<R, u64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U32 => {
                let expected = into_data_sync::<R, u32>(self.clone());
                let actual = into_data_sync::<R, u32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U16 => {
                let expected = into_data_sync::<R, u16>(self.clone());
                let actual = into_data_sync::<R, u16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U8 => {
                let expected = into_data_sync::<R, u8>(self.clone());
                let actual = into_data_sync::<R, u8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::Bool => (),
            DType::QFloat(..) => (),
        }
    }
}

impl<R> core::fmt::Debug for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
            self.shape,
            self.device,
            self.strides,
            self.dtype.name(),
            R::name(&self.client),
        ))
    }
}

impl<R> Clone for CubeTensor<R>
where
    R: CubeRuntime,
{
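    // Note: cloning is shallow. The buffer `Handle` is cloned by reference,
    // so no device memory is copied; exclusive-access checks go through
    // `can_mut` before a buffer is reused in place.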
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            handle: self.handle.clone(),
            shape: self.shape.clone(),
            device: self.device.clone(),
            strides: self.strides.clone(),
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }
}

impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
    fn dtype(&self) -> DType {
        self.dtype
    }

    fn shape(&self) -> Shape {
        self.shape.clone()
    }

    fn rank(&self) -> usize {
        self.shape.num_dims()
    }
}

impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
    fn scheme(&self) -> &burn_tensor::quantization::QuantScheme {
        if let DType::QFloat(scheme) = &self.dtype {
            scheme
        } else {
            panic!(
                "Quantization scheme is not valid for dtype {:?}",
                self.dtype,
            )
        }
    }
}

impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Create a new standard tensor.
    pub fn new(
        client: ComputeClient<R::Server>,
        handle: Handle,
        shape: Shape,
        device: R::Device,
        strides: Vec<usize>,
        dtype: DType,
    ) -> Self {
        CubeTensor {
            client,
            handle,
            shape,
            device,
            strides,
            dtype,
            qparams: None,
        }
    }

    /// Create a new tensor with a contiguous memory layout.
    pub fn new_contiguous(
        client: ComputeClient<R::Server>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

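        // Row-major strides: walking the dims from last to first gives the
        // innermost axis stride 1, and each outer axis the product of all
        // inner shapes, e.g. shape [2, 3, 4] -> strides [12, 4, 1].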
        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
            qparams: None,
        }
    }

    /// Transfer the tensor to another client and device, returning the newly transferred tensor.
    pub fn to_client(&self, client: ComputeClient<R::Server>, device: R::Device) -> Self {
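        // Describe the source region (shape, strides, element size) so the
        // current client can copy it over to the destination client.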
        let desc = self
            .handle
            .copy_descriptor(&self.shape.dims, &self.strides, self.elem_size());
        let alloc = self.client.to_client_tensor(desc, &client);

        Self {
            client,
            handle: alloc.handle,
            shape: self.shape.clone(),
            device,
            strides: alloc.strides,
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }

    /// Return the reference to a tensor handle.
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    /// Returns the element size of this tensor.
    pub fn elem_size(&self) -> usize {
        if let DType::QFloat(_) = self.dtype {
            // Quantized values are encoded (packed) as u32.
            core::mem::size_of::<u32>()
        } else {
            self.dtype.size()
        }
    }

    /// Return the reference to a tensor argument.
    pub fn as_tensor_arg<'a>(&'a self, line_size: u8) -> TensorArg<'a, R> {
        let size = self.dtype.size();
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

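        // SAFETY: the handle, strides, and shape all come from this tensor,
        // so the raw parts describe a valid allocation for the element size.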
        unsafe {
            TensorArg::from_raw_parts_and_size(
                handle.handle,
                handle.strides,
                handle.shape,
                line_size,
                size,
            )
        }
    }

    /// Return the reference to an array argument.
    pub fn as_array_arg<E: CubeElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
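        // SAFETY: the length is derived from the handle's byte size, so the
        // raw parts cover exactly the allocated buffer, measured in elements
        // of `E`.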
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

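        // The output of a broadcast op has the element-wise maximum of both
        // shapes, so the lhs buffer can only be reused in place when no lhs
        // axis is smaller than the corresponding rhs axis.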
        for i in 0..ndims {
            let shape_lhs = self.shape[i];
            let shape_rhs = rhs.shape[i];

            // Output tensor will be different from the mutable tensor.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Copy the current tensor.
    pub fn copy(&self) -> Self {
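        // Identity unary op: copying is implemented by launching a unary
        // kernel that returns its input unchanged, which materializes the
        // data into a fresh output buffer.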
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();
        launch_unary_numeric::<R, Copy, _>(tensor, |_| ())
    }

    /// Check if the tensor is safe to mutate.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Assert that both tensors are on the same device.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Check if the current tensor is contiguous.
    ///
    /// A tensor is contiguous if its elements are stored in memory in
    /// row-major order: the strides are in non-increasing order, and the
    /// stride at position `k` equals the product of the shapes at all
    /// positions greater than `k`. Axes with a shape of 1 are ignored.
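    ///
    /// # Example
    ///
    /// An illustration of the condition, using `is_contiguous` from
    /// `burn_common` (ignored in doctests since it is imported privately):
    ///
    /// ```ignore
    /// assert!(is_contiguous(&[2, 3], &[3, 1])); // row-major strides of shape [2, 3]
    /// assert!(!is_contiguous(&[2, 3], &[1, 2])); // transposed view: not contiguous
    /// ```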
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Check if the current tensor has a contiguous backing buffer (no overlap and no empty memory
    /// regions within the shape).
    pub fn is_contiguous_buffer(&self) -> bool {
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    /// Based on a bug encountered in interpolate_1d
    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
}
414}