burn_jit/tensor/
base.rs

1use crate::element::JitElement;
2use crate::kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily};
3use crate::JitRuntime;
4use burn_tensor::quantization::QTensorPrimitive;
5use burn_tensor::{DType, Shape, TensorMetadata};
6use cubecl::client::ComputeClient;
7use cubecl::frontend::Numeric;
8use cubecl::linalg::tensor::TensorHandle;
9use cubecl::prelude::{TensorHandleRef, *};
10use cubecl::server::Handle;
11use std::marker::PhantomData;
12
/// The basic tensor primitive struct.
#[derive(new)]
pub struct JitTensor<R: JitRuntime> {
    /// Compute client for the [runtime](JitRuntime).
    pub client: ComputeClient<R::Server, R::Channel>,
    /// The buffer where the data are stored.
    pub handle: Handle,
    /// The shape of the tensor.
    pub shape: Shape,
    /// The device of the tensor.
    pub device: R::Device,
    /// The strides of the tensor.
    pub strides: Vec<usize>,
    // Element type of the stored data. `QFloat` values are stored encoded
    // as `u32` in the buffer (see `elem_size`).
    pub(crate) dtype: DType,
}
28
29impl<R: JitRuntime, E: JitElement> From<JitTensor<R>> for TensorHandle<R, E> {
30    fn from(val: JitTensor<R>) -> Self {
31        TensorHandle::new(val.shape.dims.to_vec(), val.strides.to_vec(), val.handle)
32    }
33}
34
35impl<R> core::fmt::Debug for JitTensor<R>
36where
37    R: JitRuntime,
38{
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.write_fmt(format_args!(
41            "JitTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
42            self.shape,
43            self.device,
44            self.strides,
45            self.dtype.name(),
46            R::name(),
47        ))
48    }
49}
50
51impl<R> Clone for JitTensor<R>
52where
53    R: JitRuntime,
54{
55    fn clone(&self) -> Self {
56        Self {
57            client: self.client.clone(),
58            handle: self.handle.clone(),
59            shape: self.shape.clone(),
60            device: self.device.clone(),
61            strides: self.strides.clone(),
62            dtype: self.dtype,
63        }
64    }
65}
66
67impl<R: JitRuntime> TensorMetadata for JitTensor<R> {
68    fn dtype(&self) -> DType {
69        self.dtype
70    }
71
72    fn shape(&self) -> Shape {
73        self.shape.clone()
74    }
75}
76
77impl<R: JitRuntime> QTensorPrimitive for JitTensor<R> {
78    fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme {
79        if let DType::QFloat(scheme) = &self.dtype {
80            scheme
81        } else {
82            panic!(
83                "Quantization scheme is not valid for dtype {:?}",
84                self.dtype,
85            )
86        }
87    }
88}
89
/// Macro to execute a kernel/operation for a given element type.
///
/// Matches on a runtime [`burn_tensor::DType`] value, binds the matching
/// concrete Rust type to the `$element` identifier (via a local `type` alias),
/// and then evaluates `$op` with that alias in scope.
///
/// Three forms are supported:
/// - `execute_with_dtype!(float($dtype), E, op)` — floating-point dtypes only.
/// - `execute_with_dtype!(float($lhs, $rhs), E, op)` — binary float op; both
///   dtypes must be equal.
/// - `execute_with_dtype!($dtype, E, op)` — any supported numeric dtype.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_dtype {
    // Float-only dispatch: f64/f32/f16/bf16.
    (float($dtype:expr), $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            _ => unimplemented!("Unsupported dtype"),
        }
    }};

    // Binary float dispatch: enforce matching precision, then delegate to the
    // single-dtype float arm above.
    (float($lhs_dtype:expr, $rhs_dtype:expr), $element:ident, $op:expr) => {{
        // NOTE: might be better for floating point binary operations to return a Result instead?
        if $lhs_dtype != $rhs_dtype {
            panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                $lhs_dtype, $rhs_dtype
            );
        }
        execute_with_dtype!(float($lhs_dtype), $element, $op)
    }};
    // General numeric dispatch: floats plus all unsigned/signed integer widths.
    ($dtype:expr, $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            burn_tensor::DType::U64 => {
                type $element = u64;
                $op
            }
            burn_tensor::DType::U32 => {
                type $element = u32;
                $op
            }
            burn_tensor::DType::U16 => {
                type $element = u16;
                $op
            }
            burn_tensor::DType::U8 => {
                type $element = u8;
                $op
            }
            burn_tensor::DType::I64 => {
                type $element = i64;
                $op
            }
            burn_tensor::DType::I32 => {
                type $element = i32;
                $op
            }
            burn_tensor::DType::I16 => {
                type $element = i16;
                $op
            }
            burn_tensor::DType::I8 => {
                type $element = i8;
                $op
            }
            // NOTE: bool and qfloat dtypes are actually represented as u32/u8
            // burn_tensor::DType::Bool => {
            //     type $element = u32/u8;
            //     $op
            // }
            // burn_tensor::DType::QFloat(_) => {
            //     type $element = u32;
            //     $op
            // }
            _ => unimplemented!("Unsupported dtype"),
        }
    }};
}
192
impl<R> JitTensor<R>
where
    R: JitRuntime,
{
    /// Create a new tensor with a contiguous memory layout.
    ///
    /// Strides are computed in row-major order: the innermost dimension gets
    /// stride 1 and each outer stride is the product of the inner dimensions.
    pub fn new_contiguous(
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];

        // Walk dimensions from innermost to outermost, accumulating the
        // number of elements covered so far as each axis' stride.
        let mut current = 1;
        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
        }
    }

    /// Change the context of the current tensor and return the newly transferred tensor.
    ///
    /// The bytes are read back synchronously from the current client and
    /// uploaded to the target `client`; shape, strides and dtype are preserved.
    ///
    /// # Panics
    /// Panics when the read cannot complete synchronously.
    pub fn to_client(
        &self,
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
    ) -> Self {
        let bytes = burn_common::reader::try_read_sync(
            self.client.read_one_async(self.handle.clone().binding()),
        )
        .expect("Can only change client synchronously");
        let handle = client.create(&bytes);

        Self {
            client,
            handle,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            device,
            dtype: self.dtype,
        }
    }

    /// Return the reference to a tensor handle.
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    /// Size in bytes of a single element as stored in the buffer.
    fn elem_size(&self) -> usize {
        if let DType::QFloat(_) = self.dtype {
            // Encoded as u32
            core::mem::size_of::<u32>()
        } else {
            self.dtype.size()
        }
    }

    /// Return the reference to a tensor argument.
    ///
    /// NOTE(review): the element type `E` is chosen by the caller and is not
    /// checked against `self.dtype` here — mismatches are the caller's bug.
    pub fn as_tensor_arg<'a, E: JitElement>(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        unsafe {
            TensorArg::from_raw_parts::<E>(
                handle.handle,
                handle.strides,
                handle.shape,
                vectorisation,
            )
        }
    }

    /// Return the reference to an array argument.
    ///
    /// The length is derived from the handle's byte size divided by `E`'s
    /// size, so `E` must match the stored element size to be meaningful.
    pub fn as_array_arg<E: JitElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

    // Whether `self`'s buffer can serve as the output of a broadcasted binary
    // op with `rhs`: the handle must be mutable, the buffer contiguous, and
    // every dimension of `self` at least as large as the matching one of `rhs`.
    // NOTE(review): assumes `rhs` has the same rank as `self` — verify at call sites.
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape.dims[i];
            let shape_rhs = rhs.shape.dims[i];

            // Output tensor will be different from the mutable tensor.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Copy the current tensor.
    pub fn copy(&self) -> Self {
        // Identity unary operation: the kernel forwards its input unchanged,
        // producing a fresh output buffer.
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options<N: Numeric> = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();

        // Dispatch on the runtime dtype so the kernel is launched with the
        // matching concrete element type.
        execute_with_dtype!(
            tensor.dtype,
            E,
            launch_unary_numeric::<R, E, Copy, _>(tensor, |_| ())
        )
    }

    /// Check if the tensor is safe to mutate.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Assert that both tensors are on the same device.
    ///
    /// # Panics
    /// Panics with both device values when they differ.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Check if the current tensor is contiguous.
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Check if the current tensor has a contiguous backing buffer (no overlap and no empty memory
    /// regions within the shape).
    pub fn is_contiguous_buffer(&self) -> bool {
        // NOTE(review): uses `dtype.size()` rather than `elem_size()`; for
        // `QFloat` tensors (stored encoded as u32) these may differ — verify.
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}
368
/// Returns `true` when `strides` describe a dense row-major layout of `shape`.
///
/// Scalars (empty shape) are always contiguous; a 1-D tensor is contiguous
/// exactly when its single stride is 1. For higher ranks, every outer stride
/// must equal the element count of the dimensions inside it, and strides must
/// strictly increase outward. Note the strictness makes this conservative:
/// layouts with repeated strides (e.g. from size-1 dimensions) are rejected.
pub(crate) fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }
    if shape.len() == 1 {
        return strides[0] == 1;
    }

    // Elements covered by the dimensions inner to the current one; this is
    // exactly the stride a dense layout must have at that axis.
    let mut expected_stride = 1usize;
    let mut inner_stride = 0usize;

    for (pos, (&stride, &dim)) in strides.iter().zip(shape.iter()).rev().enumerate() {
        // The innermost axis (pos == 0) is never checked directly; it only
        // seeds the accumulators, mirroring the layout check's convention.
        if pos > 0 && (stride != expected_stride || stride <= inner_stride) {
            return false;
        }
        expected_stride *= dim;
        inner_stride = stride;
    }

    true
}
398
#[cfg(test)]
mod tests {
    use crate::tensor::base::is_contiguous;

    // Canonical row-major strides for a 2-D shape are contiguous.
    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    // Transposed layout (inner stride larger than outer) is not contiguous.
    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    // Sliced layout: the outermost stride (32) is smaller than the element
    // count of the inner dimensions (1 * 64), so rows overlap-skip memory.
    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    // 4-D row-major strides: each stride is the product of the inner dims.
    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    // First two axes swapped: strides are not strictly decreasing outward.
    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }
}
427}