burn_cubecl/tensor/base.rs

use crate::CubeRuntime;
use crate::element::CubeElement;
use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
use burn_backend::quantization::QuantScheme;
use burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata};
use burn_std::tensor::is_contiguous;
use cubecl::client::ComputeClient;
use cubecl::frontend::Numeric;
use cubecl::prelude::{TensorHandleRef, *};
use cubecl::server::Handle;
use cubecl::std::tensor::TensorHandle;
use std::marker::PhantomData;

use super::QParams;

/// The basic tensor primitive struct.
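///
/// A `CubeTensor` pairs a raw buffer ([`Handle`]) with the shape, stride and dtype
/// metadata needed to interpret it, along with the compute client used to launch
/// kernels on its device.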
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client for the [runtime](CubeRuntime).
    pub client: ComputeClient<R>,
    /// The buffer where the data is stored.
    pub handle: Handle,
    /// The shape of the tensor.
    pub shape: Shape,
    /// The device of the tensor.
    pub device: R::Device,
    /// The strides of the tensor.
    pub strides: Vec<usize>,
    /// The datatype of the tensor.
    pub dtype: DType,
    /// Runtime quantization parameters, if applicable.
    pub qparams: Option<QParams>,
}

impl<R: CubeRuntime> From<CubeTensor<R>> for TensorHandle<R> {
    fn from(val: CubeTensor<R>) -> Self {
        TensorHandle::new(
            val.handle,
            val.shape.to_vec(),
            val.strides.to_vec(),
            val.dtype.into(),
        )
    }
}

impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self) {
        use crate::ops::into_data_sync;
        use burn_backend::Tolerance;

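        // Read both tensors back to the host and compare them within a permissive tolerance.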
        let expected = into_data_sync::<R>(self.clone());
        let actual = into_data_sync::<R>(other);
        expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
    }
}

impl<R> core::fmt::Debug for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
            self.shape,
            self.device,
            self.strides,
            self.dtype.name(),
            R::name(&self.client),
        ))
    }
}

impl<R> Clone for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            handle: self.handle.clone(),
            shape: self.shape.clone(),
            device: self.device.clone(),
            strides: self.strides.clone(),
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }
}

impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
    fn dtype(&self) -> DType {
        self.dtype
    }

    fn shape(&self) -> Shape {
        self.shape.clone()
    }

    fn rank(&self) -> usize {
        self.shape.num_dims()
    }
}

impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
    fn scheme(&self) -> &QuantScheme {
        if let DType::QFloat(scheme) = &self.dtype {
            scheme
        } else {
            panic!(
                "Expected a quantized dtype (`DType::QFloat`), got {:?}",
                self.dtype,
            )
        }
    }
}

impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Create a new standard tensor.
    pub fn new(
        client: ComputeClient<R>,
        handle: Handle,
        shape: Shape,
        device: R::Device,
        strides: Vec<usize>,
        dtype: DType,
    ) -> Self {
        CubeTensor {
            client,
            handle,
            shape,
            device,
            strides,
            dtype,
            qparams: None,
        }
    }

    /// Create a new tensor with a contiguous memory layout.
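    ///
    /// A minimal sketch of the resulting layout (assumes a `client`, `device` and
    /// `handle` already exist; `Shape::from` is used here for illustration):
    ///
    /// ```ignore
    /// let t = CubeTensor::new_contiguous(client, device, Shape::from([2, 3, 4]), handle, DType::F32);
    /// assert_eq!(t.strides, vec![12, 4, 1]);
    /// ```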
    pub fn new_contiguous(
        client: ComputeClient<R>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

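        // Row-major (C-contiguous) strides: stride[k] = product of dims[k + 1..],
        // e.g. shape [2, 3, 4] -> strides [12, 4, 1].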
        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
            qparams: None,
        }
    }

    /// Transfer the current tensor to a new compute client and device, returning the
    /// newly transferred tensor.
    pub fn to_client(&self, client: ComputeClient<R>, device: R::Device) -> Self {
        let desc = self
            .handle
            .copy_descriptor(&self.shape.dims, &self.strides, self.elem_size());
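        // The transfer may re-allocate on the target client, so the result takes the
        // strides of the new allocation instead of assuming the original layout.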
        let alloc = self.client.to_client_tensor(desc, &client);

        Self {
            client,
            handle: alloc.handle,
            shape: self.shape.clone(),
            device,
            strides: alloc.strides,
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }

    /// Return the reference to a tensor handle.
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    /// Returns the element size of this tensor in bytes.
    pub fn elem_size(&self) -> usize {
        self.dtype.size()
    }

    /// Return the reference to a tensor argument.
    pub fn as_tensor_arg<'a>(&'a self, line_size: LineSize) -> TensorArg<'a, R> {
        let size = self.dtype.size();
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        unsafe {
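            // SAFETY: the handle, strides and shape all come from this tensor, so
            // they consistently describe a single valid allocation of `size`-byte elements.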
            TensorArg::from_raw_parts_and_size(
                handle.handle,
                handle.strides,
                handle.shape,
                line_size,
                size,
            )
        }
    }

    /// Return the reference to an array argument.
    pub fn as_array_arg<E: CubeElement>(&self, line_size: LineSize) -> ArrayArg<'_, R> {
        unsafe {
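            // SAFETY: the length is computed from the handle's byte size, so the
            // array view cannot extend past the underlying buffer.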
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                line_size,
            )
        }
    }

    /// Return the `QuantScheme` if present.
    pub fn try_scheme(&self) -> Option<&QuantScheme> {
        match &self.dtype {
            DType::QFloat(scheme) => Some(scheme),
            _ => None,
        }
    }

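    /// Check whether this tensor's buffer can be reused as the output of a broadcasted
    /// binary operation with `rhs`: the handle must be mutable, the buffer dense, and
    /// no dimension of `self` may be smaller than the corresponding dimension of `rhs`.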
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape[i];
            let shape_rhs = rhs.shape[i];

            // Broadcasting would expand this axis, so the output could not alias the
            // mutable tensor.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Copy the current tensor.
    pub fn copy(&self) -> Self {
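        // An identity element-wise op: launching it materializes the tensor's values
        // into a fresh output buffer.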
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();
        launch_unary_numeric::<R, Copy, _>(tensor, |_| ())
    }

    /// Check if the tensor is safe to mutate.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Assert that both tensors are on the same device.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Check if the current tensor is contiguous.
    ///
    /// A tensor is contiguous if its elements are stored in memory with strides in
    /// non-increasing order, and the stride at position k equals the product of the
    /// shapes at all positions greater than k. However, all axes with a shape of 1
    /// are ignored.
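    ///
    /// For example, a `[2, 3]` tensor with strides `[3, 1]` is contiguous, while the
    /// transposed view with strides `[1, 3]` is not.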
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Check if the current tensor has a contiguous backing buffer (no overlap and no
    /// empty memory regions within the shape).
    pub fn is_contiguous_buffer(&self) -> bool {
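        // Dense iff the buffer holds exactly `num_elements` elements, e.g. a [2, 3]
        // f32 tensor must own exactly 2 * 3 * 4 = 24 bytes.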
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    /// Based on a bug encountered in interpolate_1d
    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
}