1use crate::CubeRuntime;
2use crate::element::CubeElement;
3use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
4use burn_tensor::quantization::QTensorPrimitive;
5use burn_tensor::{DType, Shape, TensorMetadata};
6use cubecl::client::ComputeClient;
7use cubecl::frontend::Numeric;
8use cubecl::linalg::tensor::TensorHandle;
9use cubecl::prelude::{TensorHandleRef, *};
10use cubecl::server::Handle;
11use std::marker::PhantomData;
12
/// The basic tensor primitive for the cube backend: a device buffer plus the
/// metadata (shape, strides, dtype) needed to interpret it.
#[derive(new)]
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client for the tensor's runtime server/channel.
    pub client: ComputeClient<R::Server, R::Channel>,
    /// Handle to the underlying device buffer.
    pub handle: Handle,
    /// Logical shape of the tensor.
    pub shape: Shape,
    /// Device on which the tensor lives.
    pub device: R::Device,
    /// Per-dimension strides, in number of elements (see `new_contiguous`).
    pub strides: Vec<usize>,
    /// Element data type.
    pub dtype: DType,
}
29
30impl<R: CubeRuntime, E: CubeElement> From<CubeTensor<R>> for TensorHandle<R, E> {
31 fn from(val: CubeTensor<R>) -> Self {
32 TensorHandle::new(val.handle, val.shape.dims.to_vec(), val.strides.to_vec())
33 }
34}
35
impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    /// Asserts that `self` and `other` hold equivalent data; used by autotune
    /// (only with the `export_tests` feature) to validate that candidate
    /// kernels agree.
    ///
    /// Float dtypes are compared with per-dtype relative/absolute tolerances,
    /// integer dtypes must match exactly, and `Bool`/`QFloat` tensors are not
    /// checked at all.
    #[cfg(feature = "export_tests")]
    fn check_equivalence(&self, other: Self) {
        use burn_tensor::Tolerance;

        use crate::ops::into_data_sync;

        match self.dtype {
            DType::F64 => {
                let expected = into_data_sync::<R, f64>(self.clone());
                let actual = into_data_sync::<R, f64>(other);
                expected.assert_approx_eq::<f64>(&actual, Tolerance::rel_abs(1e-2, 1e-3));
            }
            DType::F32 | DType::Flex32 => {
                let expected = into_data_sync::<R, f32>(self.clone());
                let actual = into_data_sync::<R, f32>(other);
                expected.assert_approx_eq::<f32>(&actual, Tolerance::rel_abs(1e-2, 1e-3));
            }
            // f16 gets a looser absolute tolerance (4e-3) than the other
            // float types due to its reduced precision.
            DType::F16 => {
                let expected = into_data_sync::<R, half::f16>(self.clone());
                let actual = into_data_sync::<R, half::f16>(other);
                expected.assert_approx_eq::<half::f16>(&actual, Tolerance::rel_abs(1e-2, 4e-3));
            }
            DType::BF16 => {
                let expected = into_data_sync::<R, half::bf16>(self.clone());
                let actual = into_data_sync::<R, half::bf16>(other);
                expected.assert_approx_eq::<half::bf16>(&actual, Tolerance::rel_abs(1e-2, 1e-3));
            }
            // Integer dtypes: bitwise-exact comparison.
            DType::I64 => {
                let expected = into_data_sync::<R, i64>(self.clone());
                let actual = into_data_sync::<R, i64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I32 => {
                let expected = into_data_sync::<R, i32>(self.clone());
                let actual = into_data_sync::<R, i32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I16 => {
                let expected = into_data_sync::<R, i16>(self.clone());
                let actual = into_data_sync::<R, i16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I8 => {
                let expected = into_data_sync::<R, i8>(self.clone());
                let actual = into_data_sync::<R, i8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U64 => {
                let expected = into_data_sync::<R, u64>(self.clone());
                let actual = into_data_sync::<R, u64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U32 => {
                let expected = into_data_sync::<R, u32>(self.clone());
                let actual = into_data_sync::<R, u32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U16 => {
                let expected = into_data_sync::<R, u16>(self.clone());
                let actual = into_data_sync::<R, u16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U8 => {
                let expected = into_data_sync::<R, u8>(self.clone());
                let actual = into_data_sync::<R, u8>(other);
                expected.assert_eq(&actual, true);
            }
            // Bool and quantized tensors are deliberately left unchecked.
            DType::Bool => (),
            DType::QFloat(..) => (),
        }
    }
}
109
110impl<R> core::fmt::Debug for CubeTensor<R>
111where
112 R: CubeRuntime,
113{
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.write_fmt(format_args!(
116 "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
117 self.shape,
118 self.device,
119 self.strides,
120 self.dtype.name(),
121 R::name(&self.client),
122 ))
123 }
124}
125
// NOTE(review): `Clone` is implemented by hand, presumably because a derived
// impl would add an `R: Clone` bound that runtime marker types may not
// satisfy — confirm against the `CubeRuntime` trait definition.
impl<R> Clone for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            // Cloning the handle presumably shares the underlying buffer
            // rather than copying device memory — confirm in cubecl's
            // `Handle` docs.
            handle: self.handle.clone(),
            shape: self.shape.clone(),
            device: self.device.clone(),
            strides: self.strides.clone(),
            dtype: self.dtype,
        }
    }
}
141
impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
    /// Element data type of the tensor.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// Logical shape of the tensor (returned by clone, per the trait's
    /// owned-return signature).
    fn shape(&self) -> Shape {
        self.shape.clone()
    }
}
151
152impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
153 fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme {
154 if let DType::QFloat(scheme) = &self.dtype {
155 scheme
156 } else {
157 panic!(
158 "Quantization scheme is not valid for dtype {:?}",
159 self.dtype,
160 )
161 }
162 }
163}
164
/// Dispatches a code block over a runtime [`burn_tensor::DType`], binding the
/// matching concrete Rust element type to the `$element` identifier.
///
/// Arms:
/// - `float($dtype), E, op` — float dtypes only; hits `unimplemented!` for
///   any non-float dtype.
/// - `float($lhs, $rhs), E, op` — same as above, but first panics if the two
///   dtypes differ.
/// - `$dtype, E, op` — all numeric dtypes; `QFloat` is dispatched as packed
///   `u32` words.
#[macro_export]
macro_rules! execute_with_dtype {
    // Float-only dispatch.
    (float($dtype:expr), $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }

            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};

    // Float-only dispatch for binary ops: both operands must share a dtype.
    (float($lhs_dtype:expr, $rhs_dtype:expr), $element:ident, $op:expr) => {{
        if $lhs_dtype != $rhs_dtype {
            panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                $lhs_dtype, $rhs_dtype
            );
        }
        execute_with_dtype!(float($lhs_dtype), $element, $op)
    }};
    // General numeric dispatch over every supported dtype.
    ($dtype:expr, $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::Flex32 => {
                type $element = cubecl::flex32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            burn_tensor::DType::U64 => {
                type $element = u64;
                $op
            }
            burn_tensor::DType::U32 => {
                type $element = u32;
                $op
            }
            burn_tensor::DType::U16 => {
                type $element = u16;
                $op
            }
            burn_tensor::DType::U8 => {
                type $element = u8;
                $op
            }
            burn_tensor::DType::I64 => {
                type $element = i64;
                $op
            }
            burn_tensor::DType::I32 => {
                type $element = i32;
                $op
            }
            burn_tensor::DType::I16 => {
                type $element = i16;
                $op
            }
            burn_tensor::DType::I8 => {
                type $element = i8;
                $op
            }
            // Quantized tensors are handled as opaque packed u32 words.
            burn_tensor::DType::QFloat(_) => {
                type $element = u32;
                $op
            }
            // Bool (and any future dtype) is unsupported here.
            _ => unimplemented!("Unsupported dtype {:?}", $dtype),
        }
    }};
}
276
impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Creates a new tensor backed by `handle`, with row-major (contiguous)
    /// strides computed from `shape`.
    pub fn new_contiguous(
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

        // Walk dimensions from innermost to outermost: each stride is the
        // product (in elements) of all inner dimension sizes.
        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
        }
    }

    /// Moves the tensor's data to another compute client/device by reading
    /// the bytes back and uploading them to `client`.
    ///
    /// # Panics
    /// Panics if the read-back cannot complete synchronously.
    pub fn to_client(
        &self,
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
    ) -> Self {
        let bytes = burn_common::reader::try_read_sync(
            self.client.read_one_async(self.handle.clone().binding()),
        )
        .expect("Can only change client synchronously");
        let handle = client.create(&bytes);

        Self {
            client,
            handle,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            device,
            dtype: self.dtype,
        }
    }

    /// Borrows this tensor as a cubecl tensor handle reference.
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    /// Element size in bytes as seen by kernels.
    fn elem_size(&self) -> usize {
        if let DType::QFloat(_) = self.dtype {
            // Quantized tensors are addressed as packed `u32` words, matching
            // the `QFloat` arm of `execute_with_dtype!`.
            core::mem::size_of::<u32>()
        } else {
            self.dtype.size()
        }
    }

    /// Builds an (unchecked) kernel tensor argument with the given line size.
    ///
    /// The caller-chosen `E` must match `self.dtype`'s in-memory layout; the
    /// unsafe `from_raw_parts` constructor performs no such check.
    pub fn as_tensor_arg<'a, E: CubeElement>(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        unsafe {
            TensorArg::from_raw_parts::<E>(
                handle.handle,
                handle.strides,
                handle.shape,
                vectorisation,
            )
        }
    }

    /// Builds an (unchecked) kernel array argument covering the whole buffer,
    /// with the element count derived from the handle's byte size and `E`.
    pub fn as_array_arg<E: CubeElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

    /// Whether this tensor can serve as the in-place output of a broadcasted
    /// binary op with `rhs`: its buffer must be mutable, exactly sized, and
    /// at least as large as `rhs` in every dimension.
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape.dims[i];
            let shape_rhs = rhs.shape.dims[i];

            // A smaller lhs dimension would require broadcasting lhs itself,
            // so writing into it in place would be incorrect.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Copies this tensor into a new one by launching an identity unary
    /// kernel (the output buffer is freshly allocated by the launcher).
    pub fn copy(&self) -> Self {
        // Identity element-wise op used solely to materialize a copy.
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options<N: Numeric> = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();

        execute_with_dtype!(
            tensor.dtype,
            E,
            launch_unary_numeric::<R, E, Copy, _>(tensor, |_| ())
        )
    }

    /// Whether the underlying buffer can be mutated in place.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Asserts that `self` and `other` live on the same device.
    ///
    /// # Panics
    /// Panics with both devices' debug representations when they differ.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Whether the tensor's shape/strides describe a dense row-major layout.
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Whether the buffer is exactly the size of the tensor's data, i.e. no
    /// offset or trailing padding.
    ///
    /// NOTE(review): this uses `dtype.size()` even for `QFloat`, whereas
    /// `elem_size()` special-cases quantized tensors as u32 words — confirm
    /// the two are intended to disagree for quantized buffers.
    pub fn is_contiguous_buffer(&self) -> bool {
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}
457
/// Checks whether the `shape`/`strides` pair describes a dense row-major
/// layout: going from the innermost dimension outwards, each stride must
/// equal the number of elements covered by all inner dimensions, and strides
/// must strictly increase. Dimensions of size 1 are ignored, since their
/// stride has no effect on the memory layout.
///
/// An empty shape (rank 0) is considered contiguous.
pub(crate) fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    let mut prev_stride = 1;
    let mut current_num_elems_shape = 1;
    // Whether a dimension of size > 1 has been visited yet. The previous
    // implementation used `.enumerate()` over the reversed iterator, so a
    // skipped size-1 inner dimension still advanced the index; the first real
    // dimension then ran the checks against the initial `prev_stride = 1`,
    // wrongly rejecting contiguous layouts such as shape [4, 1] with
    // strides [1, 1].
    let mut seen_dim = false;

    for (stride, shape) in strides.iter().zip(shape).rev() {
        if *shape == 1 {
            continue;
        }

        if seen_dim {
            // The stride must cover exactly the elements of all inner dims.
            if current_num_elems_shape != *stride {
                return false;
            }

            // Strides must strictly increase towards outer dimensions
            // (rules out permuted/overlapping layouts).
            if prev_stride >= *stride {
                return false;
            }
        }

        current_num_elems_shape *= shape;
        prev_stride = *stride;
        seen_dim = true;
    }

    true
}
487
#[cfg(test)]
mod tests {
    use crate::tensor::base::is_contiguous;

    // A plain row-major layout is contiguous.
    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    // Transposed (column-major) strides are not contiguous.
    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    // A slice of a larger buffer leaves gaps between rows.
    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    // Dense 4D layout: strides are cumulative products of inner dims.
    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    // Outer two dimensions permuted relative to a dense layout.
    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }
}