burn_cubecl/tensor/
base.rs

use crate::CubeRuntime;
use crate::element::CubeElement;
use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
use burn_std::tensor::is_contiguous;
use burn_tensor::quantization::QTensorPrimitive;
use burn_tensor::{DType, Shape, TensorMetadata};
use cubecl::client::ComputeClient;
use cubecl::frontend::Numeric;
use cubecl::prelude::{TensorHandleRef, *};
use cubecl::server::Handle;
use cubecl::std::tensor::TensorHandle;
use std::marker::PhantomData;

use super::QParams;
/// The basic tensor primitive struct.
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client for the [runtime](CubeRuntime).
    pub client: ComputeClient<R>,
    /// The buffer where the data is stored.
    pub handle: Handle,
    /// The shape of the tensor.
    pub shape: Shape,
    /// The device of the tensor.
    pub device: R::Device,
    /// The strides of the tensor.
    pub strides: Vec<usize>,
    /// The datatype of the tensor.
    pub dtype: DType,
    /// Runtime quantization parameters, if applicable.
    pub qparams: Option<QParams>,
}
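
// Illustrative field relationships (a hedged sketch, not an enforced invariant):
// a contiguous `f32` tensor with `shape = [2, 3]` typically carries
// `strides = [3, 1]` and a `handle` backed by `2 * 3 * 4 = 24` bytes, with
// `qparams` populated only when `dtype` is a `DType::QFloat(..)`.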

impl<R: CubeRuntime> From<CubeTensor<R>> for TensorHandle<R> {
    fn from(val: CubeTensor<R>) -> Self {
        TensorHandle::new(
            val.handle,
            val.shape.to_vec(),
            val.strides.to_vec(),
            val.dtype.into(),
        )
    }
}

impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self) {
        use burn_tensor::Tolerance;

        use crate::ops::into_data_sync;

        match self.dtype {
            DType::F64 => {
                let expected = into_data_sync::<R, f64>(self.clone());
                let actual = into_data_sync::<R, f64>(other);
                expected.assert_approx_eq::<f64>(&actual, Tolerance::permissive());
            }
            DType::F32 | DType::Flex32 => {
                let expected = into_data_sync::<R, f32>(self.clone());
                let actual = into_data_sync::<R, f32>(other);
                expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
            }
            DType::F16 => {
                let expected = into_data_sync::<R, half::f16>(self.clone());
                let actual = into_data_sync::<R, half::f16>(other);
                expected.assert_approx_eq::<half::f16>(&actual, Tolerance::permissive());
            }
            DType::BF16 => {
                let expected = into_data_sync::<R, half::bf16>(self.clone());
                let actual = into_data_sync::<R, half::bf16>(other);
                expected.assert_approx_eq::<half::bf16>(&actual, Tolerance::permissive());
            }
            DType::I64 => {
                let expected = into_data_sync::<R, i64>(self.clone());
                let actual = into_data_sync::<R, i64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I32 => {
                let expected = into_data_sync::<R, i32>(self.clone());
                let actual = into_data_sync::<R, i32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I16 => {
                let expected = into_data_sync::<R, i16>(self.clone());
                let actual = into_data_sync::<R, i16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I8 => {
                let expected = into_data_sync::<R, i8>(self.clone());
                let actual = into_data_sync::<R, i8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U64 => {
                let expected = into_data_sync::<R, u64>(self.clone());
                let actual = into_data_sync::<R, u64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U32 => {
                let expected = into_data_sync::<R, u32>(self.clone());
                let actual = into_data_sync::<R, u32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U16 => {
                let expected = into_data_sync::<R, u16>(self.clone());
                let actual = into_data_sync::<R, u16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U8 => {
                let expected = into_data_sync::<R, u8>(self.clone());
                let actual = into_data_sync::<R, u8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::Bool => (),
            DType::QFloat(..) => (),
        }
    }
}

impl<R> core::fmt::Debug for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
            self.shape,
            self.device,
            self.strides,
            self.dtype.name(),
            R::name(&self.client),
        ))
    }
}

impl<R> Clone for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            handle: self.handle.clone(),
            shape: self.shape.clone(),
            device: self.device.clone(),
            strides: self.strides.clone(),
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }
}

impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
    fn dtype(&self) -> DType {
        self.dtype
    }

    fn shape(&self) -> Shape {
        self.shape.clone()
    }

    fn rank(&self) -> usize {
        self.shape.num_dims()
    }
}

impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
    fn scheme(&self) -> &burn_tensor::quantization::QuantScheme {
        if let DType::QFloat(scheme) = &self.dtype {
            scheme
        } else {
            panic!(
                "Quantization scheme is not valid for dtype {:?}",
                self.dtype,
            )
        }
    }
}

impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Create a new standard tensor.
    pub fn new(
        client: ComputeClient<R>,
        handle: Handle,
        shape: Shape,
        device: R::Device,
        strides: Vec<usize>,
        dtype: DType,
    ) -> Self {
        CubeTensor {
            client,
            handle,
            shape,
            device,
            strides,
            dtype,
            qparams: None,
        }
    }

    /// Create a new tensor with a contiguous memory layout.
    pub fn new_contiguous(
        client: ComputeClient<R>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
            qparams: None,
        }
    }
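
    // A quick trace of the stride loop above (a sketch, not extra behavior):
    // for `shape = [2, 3, 4]` the reverse walk assigns strides right-to-left,
    // so `current` takes the values 1 -> 4 -> 12 and `strides` ends up
    // `[12, 4, 1]`, the row-major layout that `is_contiguous` below accepts.
    // A runnable version of this trace lives in the tests module at the bottom.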

    /// Change the context of the current tensor and return the newly transferred tensor.
    pub fn to_client(&self, client: ComputeClient<R>, device: R::Device) -> Self {
        let desc = self
            .handle
            .copy_descriptor(&self.shape.dims, &self.strides, self.elem_size());
        let alloc = self.client.to_client_tensor(desc, &client);

        Self {
            client,
            handle: alloc.handle,
            shape: self.shape.clone(),
            device,
            strides: alloc.strides,
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }
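
    // Hedged usage sketch (`other_client` and `other_device` are illustrative
    // placeholders, not values provided by this module):
    //
    //     let moved = tensor.to_client(other_client, other_device);
    //     assert_eq!(moved.shape, tensor.shape);
    //     assert_eq!(moved.dtype, tensor.dtype);
    //
    // Shape, dtype, and qparams carry over unchanged; the strides are whatever
    // the transfer allocation reports, so they may differ from `tensor.strides`.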

    /// Return the reference to a tensor handle.
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    /// Returns the element size of this tensor.
    pub fn elem_size(&self) -> usize {
        self.dtype.size()
    }

    /// Return the reference to a tensor argument.
    pub fn as_tensor_arg<'a>(&'a self, line_size: u8) -> TensorArg<'a, R> {
        let size = self.dtype.size();
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        // SAFETY: the handle, strides, and shape all come from this tensor, so
        // they describe the same allocation with a matching element size.
        unsafe {
            TensorArg::from_raw_parts_and_size(
                handle.handle,
                handle.strides,
                handle.shape,
                line_size,
                size,
            )
        }
    }

    /// Return the reference to an array argument.
    pub fn as_array_arg<E: CubeElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
        // SAFETY: the length is derived from the handle's own byte size, so the
        // array view cannot extend past the underlying allocation.
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape[i];
            let shape_rhs = rhs.shape[i];

            // The broadcast output would be larger than the mutable tensor.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }
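
    // Example of the check above (a sketch): given a mutable, fully contiguous
    // buffer, `lhs.shape = [4, 4]` can host the broadcast output against
    // `rhs.shape = [1, 4]` (every lhs dim >= the rhs dim), while
    // `lhs.shape = [1, 4]` against `rhs.shape = [4, 4]` cannot, since the
    // output would outgrow the lhs buffer.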

    /// Copy the current tensor.
    pub fn copy(&self) -> Self {
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();
        launch_unary_numeric::<R, Copy, _>(tensor, |_| ())
    }

    /// Check if the tensor is safe to mutate.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Assert that both tensors are on the same device.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Check if the current tensor is contiguous.
    ///
    /// A tensor is contiguous if its elements are stored in memory
    /// with strides in non-increasing order, where the stride at position k
    /// equals the product of the shapes at all positions greater than k.
    /// Axes with a shape of 1 are ignored in this check.
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }
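
    // Concrete cases (mirroring the tests at the bottom of this file):
    // `shape = [32, 32]` with `strides = [32, 1]` is contiguous; the permuted
    // `strides = [1, 32]` is not; and unit axes are skipped, so
    // `shape = [3, 1]` with `strides = [1, 1]` still counts as contiguous.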

    /// Check if the current tensor has a contiguous backing buffer (no overlap and no empty memory
    /// regions within the shape).
    pub fn is_contiguous_buffer(&self) -> bool {
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
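
    // For example (a sketch): a [2, 2] f32 view sliced out of a [4, 4] buffer
    // covers `2 * 2 * 4 = 16` bytes of shape data inside a 64-byte handle, so
    // this returns false even though the view itself indexes valid memory.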
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    /// Based on a bug encountered in interpolate_1d
    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
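
    /// A minimal added sketch mirroring the stride loop in
    /// `CubeTensor::new_contiguous`: strides are built right-to-left, and the
    /// resulting layout is exactly what `is_contiguous` accepts.
    #[test]
    fn contiguous_strides_round_trip() {
        let shape = [2usize, 3, 4];
        let mut strides = vec![0; shape.len()];
        let mut current = 1;

        // Same reverse walk as in `new_contiguous`.
        for (index, val) in shape.iter().enumerate().rev() {
            strides[index] = current;
            current *= val;
        }

        assert_eq!(strides, vec![12, 4, 1]);
        assert!(is_contiguous(&shape, &strides));
    }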
}