burn-cubecl 0.21.0-pre.3

Generic backend that can be compiled just-in-time to any shader language target
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
use crate::CubeRuntime;
use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
use burn_backend::quantization::QuantScheme;
use burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata};
use burn_std::{Metadata, strides, tensor::is_contiguous};
use cubecl::server::Handle;
use cubecl::std::tensor::TensorHandle;
use cubecl::{client::ComputeClient, std::tensor::layout::linear::LinearViewLaunch};
use cubecl::{frontend::Numeric, std::tensor::layout::linear::LinearViewLayoutLaunch};
use cubecl::{
    prelude::{TensorBinding, *},
    std::tensor::layout::linear::LinearViewLayout,
};
use std::marker::PhantomData;

use super::QParams;

/// The basic tensor primitive struct.
/// The basic tensor primitive struct.
///
/// Layout (shape + strides) lives in [`Metadata`]; the raw buffer is an
/// untyped [`Handle`] whose element type is described by `dtype`.
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client for the [runtime](CubeRuntime).
    pub client: ComputeClient<R>,
    /// The buffer where the data are stored.
    pub handle: Handle,
    /// The metadata of the tensor (shape and strides).
    pub meta: Box<Metadata>,
    /// The device of the tensor.
    pub device: R::Device,
    /// The datatype of the tensor.
    pub dtype: DType,
    /// Runtime quantization parameters, if applicable.
    /// Only set for quantized tensors; `None` otherwise.
    pub qparams: Option<QParams>,
}

impl<R: CubeRuntime> From<CubeTensor<R>> for TensorHandle<R> {
    /// Converts a [`CubeTensor`] into a raw [`TensorHandle`].
    ///
    /// `val` is consumed, so the buffer handle is moved out directly; the
    /// previous `val.handle.clone()` was a redundant clone of an owned value.
    /// Shape and strides are cloned because the accessors return references.
    fn from(val: CubeTensor<R>) -> Self {
        // Partial move of `handle` is sound: `CubeTensor` has no `Drop` impl.
        TensorHandle::new(
            val.handle,
            val.meta.shape().clone(),
            val.meta.strides().clone(),
            val.dtype,
        )
    }
}

impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    /// Compares two autotune candidate outputs for approximate equality.
    ///
    /// Only compiled with the `autotune-checks` feature: both tensors are
    /// read back synchronously and compared element-wise at `f32` precision
    /// with a permissive tolerance. Panics on mismatch.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self) {
        use crate::ops::into_data_sync;
        use burn_backend::Tolerance;

        let expected = into_data_sync::<R>(self.clone());
        let actual = into_data_sync::<R>(other);
        expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
    }
}

// TODO: Needed to cleanup leaves tensor.
//
// Maybe not needed when fusion is activated, since we have a detector there.
// We could rely on basic GC strategy when not using fusion.
//
// impl<R: CubeRuntime> Drop for CubeTensor<R> {
//     fn drop(&mut self) {
//         todo!()
//     }
// }

impl<R> core::fmt::Debug for CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Renders a one-line summary: shape, device, strides, element type name,
    /// and runtime name. The buffer contents are intentionally not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
            self.meta.shape(),
            self.device,
            self.meta.strides(),
            self.dtype.name(),
            R::name(&self.client),
        )
    }
}

impl<R> Clone for CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Shallow-clones the tensor: the compute client and buffer handle are
    /// reference-counted clones, while metadata and qparams are deep-copied.
    fn clone(&self) -> Self {
        // Destructure so a newly added field triggers a compile error here
        // instead of being silently forgotten.
        let Self {
            client,
            handle,
            meta,
            device,
            dtype,
            qparams,
        } = self;
        Self {
            client: client.clone(),
            handle: handle.clone(),
            meta: meta.clone(),
            device: device.clone(),
            dtype: *dtype,
            qparams: qparams.clone(),
        }
    }
}

impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
    /// Returns the element datatype of this tensor.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// Returns an owned copy of the tensor's shape.
    fn shape(&self) -> Shape {
        self.meta.shape().clone()
    }

    /// Returns the number of dimensions.
    fn rank(&self) -> usize {
        self.meta.rank()
    }
}

impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
    /// Returns the quantization scheme of this tensor.
    ///
    /// # Panics
    /// Panics when the tensor's dtype is not [`DType::QFloat`]; use
    /// `try_scheme` for a non-panicking variant.
    fn scheme(&self) -> &QuantScheme {
        match &self.dtype {
            DType::QFloat(scheme) => scheme,
            other => panic!(
                "Quantization scheme is not valid for dtype {:?}",
                other,
            ),
        }
    }
}

impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Create a new standard tensor
    ///
    /// The caller supplies complete [`Metadata`] (shape + strides); no layout
    /// validation is performed here.
    pub fn new(
        client: ComputeClient<R>,
        handle: Handle,
        metadata: Metadata,
        device: R::Device,
        dtype: DType,
    ) -> Self {
        CubeTensor {
            client,
            handle,
            meta: Box::new(metadata),
            device,
            dtype,
            qparams: None,
        }
    }

    /// Create a new tensor with a contiguous memory layout.
    ///
    /// Strides are computed row-major: the last dimension has stride 1 and
    /// each preceding dimension's stride is the product of all later shapes.
    pub fn new_contiguous(
        client: ComputeClient<R>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = strides![0; ndims];
        let mut current = 1;

        // Walk dimensions from innermost to outermost, accumulating the
        // running element count as the stride of each dimension.
        shape.iter().enumerate().rev().for_each(|(index, val)| {
            strides[index] = current;
            current *= val;
        });

        Self {
            client,
            handle,
            meta: Box::new(Metadata::new(shape, strides)),
            device,
            dtype,
            qparams: None,
        }
    }

    /// Change the context of the current tensor and return the newly transferred tensor.
    ///
    /// Copies the buffer to the target client/device; shape, strides, dtype
    /// and quantization parameters are preserved unchanged.
    pub fn to_client(&self, client: ComputeClient<R>, device: R::Device) -> Self {
        let desc = self.handle.clone().copy_descriptor(
            self.meta.shape().clone(),
            self.meta.strides().clone(),
            self.elem_size(),
        );
        let handle = self.client.to_client_tensor(desc, &client);

        Self {
            client,
            handle,
            meta: Box::new(Metadata::new(self.shape(), self.meta.strides().clone())),
            device,
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }

    /// Return the reference to a tensor handle.
    ///
    /// Consumes the tensor; shape and strides are moved out of the metadata.
    pub fn binding(self) -> TensorBinding<R> {
        TensorBinding {
            handle: self.handle.binding(),
            strides: self.meta.strides,
            shape: self.meta.shape,
            runtime: PhantomData,
        }
    }

    /// Returns the element size of this tensor
    pub fn elem_size(&self) -> usize {
        self.dtype.size()
    }

    /// Return the reference to a tensor argument.
    pub fn into_tensor_arg(self) -> TensorArg<R> {
        self.binding().into_tensor_arg()
    }

    /// Return the reference to an array argument.
    pub fn into_array_arg(self) -> ArrayArg<R> {
        self.into_tensor_arg().into_array_arg()
    }

    /// Returns a reference to the aliased tensor argument.
    ///
    /// `input_pos` is the position of the input this output aliases, so the
    /// kernel can write in place over that input's buffer.
    pub fn as_tensor_alias(&self, input_pos: usize) -> TensorArg<R> {
        TensorArg::Alias {
            input_pos,
            strides: self.meta.strides().clone(),
            shape: self.meta.shape().clone(),
        }
    }

    /// Return a linear view of this tensor.
    pub fn into_linear_view(self) -> LinearViewLaunch<R> {
        let layout = LinearViewLayoutLaunch::new();
        let buffer = self.into_tensor_arg();
        LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)
    }

    /// Return an aliased linear view of this tensor
    pub fn as_linear_view_alias(&self, input_pos: usize) -> LinearViewLaunch<R> {
        let layout = LinearViewLayoutLaunch::new();
        let buffer = self.as_tensor_alias(input_pos);
        LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)
    }

    /// Return a linear view broadcast to the reference tensor's shape
    pub fn into_linear_view_like(self, reference: &Self) -> LinearViewLaunch<R> {
        let layout = LinearViewLayoutLaunch::from_reference_shape(reference.shape());
        let buffer = self.into_tensor_arg();
        LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)
    }

    /// Returns the address type required to index this tensor
    ///
    /// The address width is chosen from the logical element count. For
    /// quantized tensors, values are packed, so the count is derived from
    /// the buffer's bit size divided by the bits per quantized value.
    pub fn required_address_type(&self) -> AddressType {
        match self.try_scheme() {
            Some(scheme) => {
                // bytes * 8 / bits-per-value = number of packed values
                let len = self.handle.size() as usize * 8 / scheme.size_bits_value();
                AddressType::from_len(len)
            }
            None => AddressType::from_len(self.handle.size() as usize / self.dtype.size()),
        }
    }

    /// Return the `QuantScheme` if present
    pub fn try_scheme(&self) -> Option<&QuantScheme> {
        match &self.dtype {
            DType::QFloat(scheme) => Some(scheme),
            _ => None,
        }
    }

    /// Whether `self`'s buffer can be reused as the output of a broadcasted
    /// binary op with `rhs`: the handle must be mutable, the layout must be
    /// non-overlapping, and every dimension of `self` must be at least as
    /// large as the corresponding dimension of `rhs`.
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_nonoverlapping() {
            return false;
        }
        let ndims = self.meta.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.meta.shape()[i];
            let shape_rhs = rhs.meta.shape()[i];

            // Output tensor will be different from the mutable tensor.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Copy the current tensor.
    ///
    /// Implemented as an identity element-wise kernel launch, which produces
    /// a fresh buffer with the same values.
    pub fn copy(&self) -> Self {
        // Local identity op: returns its input unchanged.
        struct Copy;

        #[cube]
        impl<T: Numeric, N: Size> NumericUnaryOp<T, N> for Copy {
            type Options = ();

            fn execute(input: Vector<T, N>, _options: &Self::Options) -> Vector<T, N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options = ();
            type Unary<T: Numeric, N: Size> = Self;
        }

        let tensor = self.clone();
        launch_unary_numeric::<R, Copy, _>(tensor, |_| ())
    }

    /// Check if the tensor is safe to mutate.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Assert that both tensors are on the same device.
    ///
    /// # Panics
    /// Panics when the devices differ.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Check if the current tensor is contiguous.
    ///
    /// A tensor is contiguous if the elements are stored in memory
    /// if the strides in non-increasing order and the
    /// strides at position k is equal to the product of the shapes
    /// at all positions greater than k. However, all axes with a shape of 1 are ignored.
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(self.meta.shape(), self.meta.strides())
    }

    /// Check if the current tensor has a contiguous backing buffer (no overlap and no empty memory
    /// regions within the shape).
    pub fn is_contiguous_buffer(&self) -> bool {
        // True exactly when the buffer's byte size matches the logical size.
        self.meta.shape().num_elements() * self.dtype.size() == self.handle.size() as usize
    }

    /// Checks if the tensor is non-overlapping (can be safely written to).
    ///
    /// Sorts dimensions by stride and verifies that each dimension's stride
    /// exceeds the maximum offset reachable by all smaller-stride dimensions,
    /// so no two index tuples map to the same memory location.
    pub fn is_nonoverlapping(&self) -> bool {
        let shape = self.meta.shape();
        let strides = self.meta.strides();

        // A zero stride means a dimension is broadcast over the same memory.
        if strides.contains(&0) {
            return false;
        }
        let rank = self.rank();
        if rank > 1 {
            let mut dims = shape.iter().zip(strides.iter()).collect::<Vec<_>>();
            dims.sort_by_key(|(_, stride)| **stride);

            let mut max_offset = 0;
            for (shape, stride) in dims.into_iter() {
                // Size-1 dims contribute no extent and cannot overlap.
                if *stride <= max_offset && *shape != 1 {
                    return false;
                }

                max_offset += (*shape - 1) * *stride;
            }
        }
        true
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Size-1 trailing axis is ignored, so equal strides still count as contiguous.
    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    // Standard row-major layout.
    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    // Transposed (column-major) layout is not contiguous.
    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    // Strides leave gaps relative to the shape: a sliced view.
    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    // 4D row-major layout with matching stride products.
    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    // Two outer axes swapped: strides not in non-increasing order.
    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    /// Based on a bug encountered in interpolate_1d
    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
}