Skip to main content

burn_cubecl/tensor/
base.rs

1use crate::CubeRuntime;
2use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
3use burn_backend::quantization::QuantScheme;
4use burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata};
5use burn_std::{Metadata, strides, tensor::is_contiguous};
6use cubecl::server::Handle;
7use cubecl::std::tensor::TensorHandle;
8use cubecl::{client::ComputeClient, std::tensor::layout::linear::LinearViewLaunch};
9use cubecl::{frontend::Numeric, std::tensor::layout::linear::LinearViewLayoutLaunch};
10use cubecl::{
11    prelude::{TensorBinding, *},
12    std::tensor::layout::linear::LinearViewLayout,
13};
14use std::marker::PhantomData;
15
16use super::QParams;
17
18/// The basic tensor primitive struct.
19pub struct CubeTensor<R: CubeRuntime> {
20    /// Compute client for the [runtime](CubeRuntime).
21    pub client: ComputeClient<R>,
22    /// The buffer where the data are stored.
23    pub handle: Handle,
24    /// The metadata of the tensor.
25    pub meta: Box<Metadata>,
26    /// The device of the tensor.
27    pub device: R::Device,
28    /// The datatype of the tensor.
29    pub dtype: DType,
30    /// Runtime quantization parameters, if applicable
31    pub qparams: Option<QParams>,
32}
33
34impl<R: CubeRuntime> From<CubeTensor<R>> for TensorHandle<R> {
35    fn from(val: CubeTensor<R>) -> Self {
36        TensorHandle::new(
37            val.handle.clone(),
38            val.meta.shape().clone(),
39            val.meta.strides().clone(),
40            val.dtype,
41        )
42    }
43}
44
45impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
46    #[cfg(feature = "autotune-checks")]
47    fn check_equivalence(&self, other: Self) {
48        use crate::ops::into_data_sync;
49        use burn_backend::Tolerance;
50
51        let expected = into_data_sync::<R>(self.clone());
52        let actual = into_data_sync::<R>(other);
53        expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
54    }
55}
56
// TODO: Needed to clean up leaf tensors.
//
// Maybe not needed when fusion is activated, since we have a detector there.
// We could rely on a basic GC strategy when not using fusion.
61//
62// impl<R: CubeRuntime> Drop for CubeTensor<R> {
63//     fn drop(&mut self) {
64//         todo!()
65//     }
66// }
67
68impl<R> core::fmt::Debug for CubeTensor<R>
69where
70    R: CubeRuntime,
71{
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.write_fmt(format_args!(
74            "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
75            self.meta.shape(),
76            self.device,
77            self.meta.strides(),
78            self.dtype.name(),
79            R::name(&self.client),
80        ))
81    }
82}
83
84impl<R> Clone for CubeTensor<R>
85where
86    R: CubeRuntime,
87{
88    fn clone(&self) -> Self {
89        Self {
90            client: self.client.clone(),
91            handle: self.handle.clone(),
92            meta: self.meta.clone(),
93            device: self.device.clone(),
94            dtype: self.dtype,
95            qparams: self.qparams.clone(),
96        }
97    }
98}
99
100impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
101    fn dtype(&self) -> DType {
102        self.dtype
103    }
104
105    fn shape(&self) -> Shape {
106        self.meta.shape().clone()
107    }
108
109    fn rank(&self) -> usize {
110        self.meta.rank()
111    }
112}
113
114impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
115    fn scheme(&self) -> &QuantScheme {
116        if let DType::QFloat(scheme) = &self.dtype {
117            scheme
118        } else {
119            panic!(
120                "Quantization scheme is not valid for dtype {:?}",
121                self.dtype,
122            )
123        }
124    }
125}
126
127impl<R> CubeTensor<R>
128where
129    R: CubeRuntime,
130{
131    /// Create a new standard tensor
132    pub fn new(
133        client: ComputeClient<R>,
134        handle: Handle,
135        metadata: Metadata,
136        device: R::Device,
137        dtype: DType,
138    ) -> Self {
139        CubeTensor {
140            client,
141            handle,
142            meta: Box::new(metadata),
143            device,
144            dtype,
145            qparams: None,
146        }
147    }
148
149    /// Create a new tensor with a contiguous memory layout.
150    pub fn new_contiguous(
151        client: ComputeClient<R>,
152        device: R::Device,
153        shape: Shape,
154        handle: Handle,
155        dtype: DType,
156    ) -> Self {
157        let ndims = shape.num_dims();
158        let mut strides = strides![0; ndims];
159        let mut current = 1;
160
161        shape.iter().enumerate().rev().for_each(|(index, val)| {
162            strides[index] = current;
163            current *= val;
164        });
165
166        Self {
167            client,
168            handle,
169            meta: Box::new(Metadata::new(shape, strides)),
170            device,
171            dtype,
172            qparams: None,
173        }
174    }
175
176    /// Change the context of the current tensor and return the newly transferred tensor.
177    pub fn to_client(&mut self, client: ComputeClient<R>, device: R::Device) -> Self {
178        let desc = self.handle.clone().copy_descriptor(
179            self.meta.shape().clone(),
180            self.meta.strides().clone(),
181            self.elem_size(),
182        );
183        let handle = self
184            .client
185            .to_client_tensor(desc, &client, self.dtype.into());
186
187        Self {
188            client,
189            handle,
190            meta: Box::new(Metadata::new(self.shape(), self.meta.strides().clone())),
191            device,
192            dtype: self.dtype,
193            qparams: self.qparams.clone(),
194        }
195    }
196
197    /// Return the reference to a tensor handle.
198    pub fn binding(self) -> TensorBinding<R> {
199        TensorBinding {
200            handle: self.handle.binding(),
201            strides: self.meta.strides,
202            shape: self.meta.shape,
203            runtime: PhantomData,
204        }
205    }
206
207    /// Returns the element size of this tensor
208    pub fn elem_size(&self) -> usize {
209        self.dtype.size()
210    }
211
212    /// Return the reference to a tensor argument.
213    pub fn into_tensor_arg(self) -> TensorArg<R> {
214        self.binding().into_tensor_arg()
215    }
216
217    /// Return the reference to an array argument.
218    pub fn into_array_arg(self) -> ArrayArg<R> {
219        self.into_tensor_arg().into_array_arg()
220    }
221
222    /// Returns a reference to the aliased tensor argument.
223    pub fn as_tensor_alias(&self, input_pos: usize) -> TensorArg<R> {
224        TensorArg::Alias {
225            input_pos,
226            strides: self.meta.strides().clone(),
227            shape: self.meta.shape().clone(),
228        }
229    }
230
231    /// Return a linear view of this tensor.
232    pub fn into_linear_view(self) -> LinearViewLaunch<R> {
233        let layout = LinearViewLayoutLaunch::new();
234        let buffer = self.into_tensor_arg();
235        LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)
236    }
237
238    /// Return an aliased linear view of this tensor
239    pub fn as_linear_view_alias(&self, input_pos: usize) -> LinearViewLaunch<R> {
240        let layout = LinearViewLayoutLaunch::new();
241        let buffer = self.as_tensor_alias(input_pos);
242        LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)
243    }
244
245    /// Return a linear view broadcast to the reference tensor's shape
246    pub fn into_linear_view_like(self, reference: &Self) -> LinearViewLaunch<R> {
247        let layout = LinearViewLayoutLaunch::from_reference_shape(reference.shape());
248        let buffer = self.into_tensor_arg();
249        LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)
250    }
251
252    /// Returns the address type required to index this tensor
253    pub fn required_address_type(&self) -> AddressType {
254        match self.try_scheme() {
255            Some(scheme) => {
256                let len = self.handle.size() as usize * 8 / scheme.size_bits_value();
257                AddressType::from_len(len)
258            }
259            None => AddressType::from_len(self.handle.size() as usize / self.dtype.size()),
260        }
261    }
262
263    /// Return the `QuantScheme` if present
264    pub fn try_scheme(&self) -> Option<&QuantScheme> {
265        match &self.dtype {
266            DType::QFloat(scheme) => Some(scheme),
267            _ => None,
268        }
269    }
270
271    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
272        if !self.handle.can_mut() || !self.is_nonoverlapping() {
273            return false;
274        }
275        let ndims = self.meta.num_dims();
276
277        for i in 0..ndims {
278            let shape_lhs = self.meta.shape()[i];
279            let shape_rhs = rhs.meta.shape()[i];
280
281            // Output tensor will be different from the mutable tensor.
282            if shape_lhs < shape_rhs {
283                return false;
284            }
285        }
286
287        true
288    }
289
290    /// Copy the current tensor.
291    pub fn copy(&self) -> Self {
292        struct Copy;
293
294        #[cube]
295        impl<T: Numeric, N: Size> NumericUnaryOp<T, N> for Copy {
296            type Options = ();
297
298            fn execute(input: Vector<T, N>, _options: &Self::Options) -> Vector<T, N> {
299                input
300            }
301        }
302
303        impl NumericUnaryOpFamily for Copy {
304            type Options = ();
305            type Unary<T: Numeric, N: Size> = Self;
306        }
307
308        let tensor = self.clone();
309        launch_unary_numeric::<R, Copy, _>(tensor, |_| ())
310    }
311
312    /// Check if the tensor is safe to mutate.
313    pub fn can_mut(&self) -> bool {
314        self.handle.can_mut()
315    }
316
317    /// Assert that both tensors are on the same device.
318    pub fn assert_is_on_same_device(&self, other: &Self) {
319        if self.device != other.device {
320            panic!(
321                "Both tensors should be on the same device {:?} != {:?}",
322                self.device, other.device
323            );
324        }
325    }
326
327    /// Check if the current tensor is contiguous.
328    ///
329    /// A tensor is contiguous if the elements are stored in memory
330    /// if the strides in non-increasing order and the
331    /// strides at position k is equal to the product of the shapes
332    /// at all positions greater than k. However, all axes with a shape of 1 are ignored.
333    pub fn is_contiguous(&self) -> bool {
334        is_contiguous(self.meta.shape(), self.meta.strides())
335    }
336
337    /// Check if the current tensor has a contiguous backing buffer (no overlap and no empty memory
338    /// regions within the shape).
339    pub fn is_contiguous_buffer(&self) -> bool {
340        self.meta.shape().num_elements() * self.dtype.size() == self.handle.size() as usize
341    }
342
343    /// Checks if the tensor is non-overlapping (can be safely written to).
344    pub fn is_nonoverlapping(&self) -> bool {
345        let shape = self.meta.shape();
346        let strides = self.meta.strides();
347
348        if strides.contains(&0) {
349            return false;
350        }
351        let rank = self.rank();
352        if rank > 1 {
353            let mut dims = shape.iter().zip(strides.iter()).collect::<Vec<_>>();
354            dims.sort_by_key(|(_, stride)| **stride);
355
356            let mut max_offset = 0;
357            for (shape, stride) in dims.into_iter() {
358                if *stride <= max_offset && *shape != 1 {
359                    return false;
360                }
361
362                max_offset += (*shape - 1) * *stride;
363            }
364        }
365        true
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal strides are accepted when the second axis has a shape of 1.
    #[test]
    fn is_contiguous_non_increasing() {
        let (shape, strides) = ([3, 1], [1, 1]);
        assert!(is_contiguous(&shape, &strides));
    }

    /// A plain 2D layout whose row stride equals the column count is contiguous.
    #[test]
    fn is_contiguous_basic() {
        let (shape, strides) = ([32, 32], [32, 1]);
        assert!(is_contiguous(&shape, &strides));
    }

    /// Swapping the strides of the 2D layout makes it non-contiguous.
    #[test]
    fn is_contiguous_permuted() {
        let (shape, strides) = ([32, 32], [1, 32]);
        assert!(!is_contiguous(&shape, &strides));
    }

    /// A first-axis stride smaller than the extent of the last axis is rejected.
    #[test]
    fn is_contiguous_slice() {
        let (shape, strides) = ([32, 1, 64], [32, 64, 1]);
        assert!(!is_contiguous(&shape, &strides));
    }

    /// A 4D layout whose strides match the products of the trailing shapes is contiguous.
    #[test]
    fn is_contiguous_4d_positive() {
        let (shape, strides) = ([8, 256, 32, 32], [262144, 1024, 32, 1]);
        assert!(is_contiguous(&shape, &strides));
    }

    /// The same 4D layout with its first two axes swapped is not contiguous.
    #[test]
    fn is_contiguous_4d_negative() {
        let (shape, strides) = ([256, 8, 32, 32], [1024, 262144, 32, 1]);
        assert!(!is_contiguous(&shape, &strides));
    }

    /// Based on a bug encountered in interpolate_1d
    #[test]
    fn is_contiguous_4d_unit_shape() {
        let (shape, strides) = ([1, 1, 1, 9], [72, 1, 72, 8]);
        assert!(!is_contiguous(&shape, &strides));
    }
}