use crate::CubeRuntime;
use crate::element::CubeElement;
use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
use burn_common::tensor::is_contiguous;
use burn_tensor::quantization::QTensorPrimitive;
use burn_tensor::{DType, Shape, TensorMetadata};
use cubecl::client::ComputeClient;
use cubecl::frontend::Numeric;
use cubecl::linalg::tensor::TensorHandle;
use cubecl::prelude::{TensorHandleRef, *};
use cubecl::server::Handle;
use std::marker::PhantomData;

/// The basic tensor primitive for CubeCL runtimes.
#[derive(new)]
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client for the tensor's runtime.
    pub client: ComputeClient<R::Server, R::Channel>,
    /// The buffer where the tensor data is stored.
    pub handle: Handle,
    /// The shape of the tensor.
    pub shape: Shape,
    /// The device on which the tensor lives.
    pub device: R::Device,
    /// The strides of the tensor, in number of elements.
    pub strides: Vec<usize>,
    /// The data type of the tensor elements.
    pub dtype: DType,
}

impl<R: CubeRuntime, E: CubeElement> From<CubeTensor<R>> for TensorHandle<R, E> {
    fn from(val: CubeTensor<R>) -> Self {
        TensorHandle::new(val.handle, val.shape.dims.to_vec(), val.strides.to_vec())
    }
}

impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    #[cfg(feature = "export_tests")]
    fn check_equivalence(&self, other: Self) {
        use burn_tensor::Tolerance;

        use crate::ops::into_data_sync;

        // Floating point outputs are compared with loose tolerances; integer
        // outputs must match exactly.
        match self.dtype {
            DType::F64 => {
                let expected = into_data_sync::<R, f64>(self.clone());
                let actual = into_data_sync::<R, f64>(other);
                expected.assert_approx_eq::<f64>(&actual, Tolerance::rel_abs(1e-2, 1e-3));
            }
            DType::F32 | DType::Flex32 => {
                let expected = into_data_sync::<R, f32>(self.clone());
                let actual = into_data_sync::<R, f32>(other);
                expected.assert_approx_eq::<f32>(&actual, Tolerance::rel_abs(1e-2, 1e-3));
            }
            DType::F16 => {
                let expected = into_data_sync::<R, half::f16>(self.clone());
                let actual = into_data_sync::<R, half::f16>(other);
                expected.assert_approx_eq::<half::f16>(&actual, Tolerance::rel_abs(1e-2, 4e-3));
            }
            DType::BF16 => {
                let expected = into_data_sync::<R, half::bf16>(self.clone());
                let actual = into_data_sync::<R, half::bf16>(other);
                expected.assert_approx_eq::<half::bf16>(&actual, Tolerance::rel_abs(1e-2, 1e-3));
            }
            DType::I64 => {
                let expected = into_data_sync::<R, i64>(self.clone());
                let actual = into_data_sync::<R, i64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I32 => {
                let expected = into_data_sync::<R, i32>(self.clone());
                let actual = into_data_sync::<R, i32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I16 => {
                let expected = into_data_sync::<R, i16>(self.clone());
                let actual = into_data_sync::<R, i16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I8 => {
                let expected = into_data_sync::<R, i8>(self.clone());
                let actual = into_data_sync::<R, i8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U64 => {
                let expected = into_data_sync::<R, u64>(self.clone());
                let actual = into_data_sync::<R, u64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U32 => {
                let expected = into_data_sync::<R, u32>(self.clone());
                let actual = into_data_sync::<R, u32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U16 => {
                let expected = into_data_sync::<R, u16>(self.clone());
                let actual = into_data_sync::<R, u16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U8 => {
                let expected = into_data_sync::<R, u8>(self.clone());
                let actual = into_data_sync::<R, u8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::Bool => (),
            DType::QFloat(..) => (),
        }
    }
}

impl<R> core::fmt::Debug for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {} }}",
            self.shape,
            self.device,
            self.strides,
            self.dtype.name(),
            R::name(&self.client),
        ))
    }
}

impl<R> Clone for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            handle: self.handle.clone(),
            shape: self.shape.clone(),
            device: self.device.clone(),
            strides: self.strides.clone(),
            dtype: self.dtype,
        }
    }
}

impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
    fn dtype(&self) -> DType {
        self.dtype
    }

    fn shape(&self) -> Shape {
        self.shape.clone()
    }
}

impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
    fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme {
        if let DType::QFloat(scheme) = &self.dtype {
            scheme
        } else {
            panic!(
                "Quantization scheme is not valid for dtype {:?}",
                self.dtype,
            )
        }
    }
}

#[macro_export]
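/// Execute an operation with a concrete element type resolved from a runtime
/// `DType`: the matched Rust type is bound to the `$element` identifier,
/// which `$op` can use as a generic argument. The `float(...)` forms only
/// accept floating point dtypes.
///
/// # Panics
/// The `float($lhs, $rhs)` form panics when the two dtypes differ, since
/// there is no automatic cast between floating point precisions.
///
/// # Example
///
/// A minimal sketch; `elem_size_of` is a hypothetical helper, not part of
/// this crate:
///
/// ```ignore
/// fn elem_size_of(dtype: burn_tensor::DType) -> usize {
///     execute_with_dtype!(dtype, E, core::mem::size_of::<E>())
/// }
/// ```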
macro_rules! execute_with_dtype {
    (float($dtype:expr), $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};

    (float($lhs_dtype:expr, $rhs_dtype:expr), $element:ident, $op:expr) => {{
        if $lhs_dtype != $rhs_dtype {
            panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                $lhs_dtype, $rhs_dtype
            );
        }
        execute_with_dtype!(float($lhs_dtype), $element, $op)
    }};
    ($dtype:expr, $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            burn_tensor::DType::U64 => {
                type $element = u64;
                $op
            }
            burn_tensor::DType::U32 => {
                type $element = u32;
                $op
            }
            burn_tensor::DType::U16 => {
                type $element = u16;
                $op
            }
            burn_tensor::DType::U8 => {
                type $element = u8;
                $op
            }
            burn_tensor::DType::I64 => {
                type $element = i64;
                $op
            }
            burn_tensor::DType::I32 => {
                type $element = i32;
                $op
            }
            burn_tensor::DType::I16 => {
                type $element = i16;
                $op
            }
            burn_tensor::DType::I8 => {
                type $element = i8;
                $op
            }
            // Quantized tensors are represented as packed `u32` values in
            // kernels (see `elem_size`).
            burn_tensor::DType::QFloat(_) => {
                type $element = u32;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};
}

impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
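    /// Create a new tensor with a contiguous, row-major memory layout.
    ///
    /// Strides are derived from the shape; e.g. shape `[2, 3, 4]` yields
    /// strides `[12, 4, 1]`.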
    pub fn new_contiguous(
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

        // Walk the dimensions from innermost to outermost, accumulating the
        // number of elements spanned by one step along each dimension.
        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
        }
    }

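    /// Transfer the tensor to another compute client and device by reading
    /// the data back synchronously and uploading it to the new client.
    ///
    /// # Panics
    /// Panics if the data cannot be read back synchronously.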
    pub fn to_client(
        &self,
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
    ) -> Self {
        let bytes = burn_common::reader::try_read_sync(
            self.client.read_one_async(self.handle.clone().binding()),
        )
        .expect("Can only change client synchronously");
        let handle = client.create(&bytes);

        Self {
            client,
            handle,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            device,
            dtype: self.dtype,
        }
    }

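    /// Return a `TensorHandleRef` borrowing this tensor's buffer, strides, and shape.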
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    fn elem_size(&self) -> usize {
        if let DType::QFloat(_) = self.dtype {
            // Quantized values are packed into u32 words.
            core::mem::size_of::<u32>()
        } else {
            self.dtype.size()
        }
    }

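    /// Return the tensor as a kernel tensor argument with the given
    /// vectorisation (line size) factor.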
    pub fn as_tensor_arg<'a, E: CubeElement>(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        // SAFETY: the handle, strides, and shape all come from this tensor;
        // the caller must ensure `E` matches the tensor's dtype.
        unsafe {
            TensorArg::from_raw_parts::<E>(
                handle.handle,
                handle.strides,
                handle.shape,
                vectorisation,
            )
        }
    }

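    /// Return the tensor as a flat kernel array argument with the given
    /// vectorisation (line size) factor.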
    pub fn as_array_arg<E: CubeElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
        // SAFETY: the length is derived from the buffer size in bytes; the
        // caller must ensure `E` matches the tensor's dtype.
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

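    /// Check if this tensor's buffer can be reused in place as the output of
    /// a broadcasted binary operation with `rhs`: the buffer must be mutable,
    /// tightly packed, and at least as large as `rhs` in every dimension.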
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape.dims[i];
            let shape_rhs = rhs.shape.dims[i];

            // A dimension that would be broadcast makes the output larger
            // than this buffer, so the buffer cannot hold the result.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

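    /// Copy the tensor into a new buffer.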
    pub fn copy(&self) -> Self {
        // Implemented as an identity unary kernel: every element is read and
        // written unchanged into the output tensor.
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options<N: Numeric> = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();

        execute_with_dtype!(
            tensor.dtype,
            E,
            launch_unary_numeric::<R, E, Copy, _>(tensor, |_| ())
        )
    }

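    /// Check if the tensor's buffer is safe to mutate in place.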
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

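    /// Assert that both tensors are on the same device.
    ///
    /// # Panics
    /// Panics if the devices differ.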
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

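    /// Check if the tensor's memory layout is contiguous.
    ///
    /// A tensor is contiguous when its elements are laid out in row-major
    /// order without gaps; strides of size-1 dimensions never affect the
    /// layout, so they can take any value.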
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

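    /// Check if the tensor's buffer exactly covers its elements, i.e. the
    /// handle is not a view into a larger allocation.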
    pub fn is_contiguous_buffer(&self) -> bool {
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }
}