1use crate::CubeRuntime;
2use crate::element::CubeElement;
3use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
4use burn_common::tensor::is_contiguous;
5use burn_tensor::quantization::QTensorPrimitive;
6use burn_tensor::{DType, Shape, TensorMetadata};
7use cubecl::client::ComputeClient;
8use cubecl::frontend::Numeric;
9use cubecl::prelude::{TensorHandleRef, *};
10use cubecl::server::Handle;
11use cubecl::std::tensor::TensorHandle;
12use std::marker::PhantomData;
13
14use super::QParams;
15
/// Tensor primitive for the CubeCL runtimes: a raw device buffer plus the
/// metadata needed to interpret it as an n-dimensional tensor.
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client used to submit work for this tensor's device.
    pub client: ComputeClient<R::Server>,
    /// Handle to the underlying device memory buffer.
    pub handle: Handle,
    /// Logical shape (extent of each dimension).
    pub shape: Shape,
    /// Device on which the tensor is allocated.
    pub device: R::Device,
    /// Per-dimension strides, in elements; may describe a non-contiguous
    /// (e.g. permuted or sliced) layout.
    pub strides: Vec<usize>,
    /// Element data type.
    pub dtype: DType,
    /// Quantization parameters; `Some` only for quantized tensors.
    pub qparams: Option<QParams>,
}
33
34impl<R: CubeRuntime> From<CubeTensor<R>> for TensorHandle<R> {
35 fn from(val: CubeTensor<R>) -> Self {
36 TensorHandle::new(
37 val.handle,
38 val.shape.to_vec(),
39 val.strides.to_vec(),
40 val.dtype.into(),
41 )
42 }
43}
44
/// Autotuner output support: with the `autotune-checks` feature enabled,
/// results of competing kernel candidates are compared for equivalence.
impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    // Only compiled with the `autotune-checks` feature; without it this impl
    // is empty (presumably the trait provides a no-op default — confirm in
    // cubecl).
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self) {
        use burn_tensor::Tolerance;

        use crate::ops::into_data_sync;

        // Read both tensors back to the host and compare element-wise.
        // Floating-point dtypes use an approximate comparison with a
        // permissive tolerance (candidate kernels may differ in reduction
        // order); integer dtypes must match exactly.
        match self.dtype {
            DType::F64 => {
                let expected = into_data_sync::<R, f64>(self.clone());
                let actual = into_data_sync::<R, f64>(other);
                expected.assert_approx_eq::<f64>(&actual, Tolerance::permissive());
            }
            // Flex32 data is read back and compared as f32.
            DType::F32 | DType::Flex32 => {
                let expected = into_data_sync::<R, f32>(self.clone());
                let actual = into_data_sync::<R, f32>(other);
                expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
            }
            DType::F16 => {
                let expected = into_data_sync::<R, half::f16>(self.clone());
                let actual = into_data_sync::<R, half::f16>(other);
                expected.assert_approx_eq::<half::f16>(&actual, Tolerance::permissive());
            }
            DType::BF16 => {
                let expected = into_data_sync::<R, half::bf16>(self.clone());
                let actual = into_data_sync::<R, half::bf16>(other);
                expected.assert_approx_eq::<half::bf16>(&actual, Tolerance::permissive());
            }
            DType::I64 => {
                let expected = into_data_sync::<R, i64>(self.clone());
                let actual = into_data_sync::<R, i64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I32 => {
                let expected = into_data_sync::<R, i32>(self.clone());
                let actual = into_data_sync::<R, i32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I16 => {
                let expected = into_data_sync::<R, i16>(self.clone());
                let actual = into_data_sync::<R, i16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I8 => {
                let expected = into_data_sync::<R, i8>(self.clone());
                let actual = into_data_sync::<R, i8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U64 => {
                let expected = into_data_sync::<R, u64>(self.clone());
                let actual = into_data_sync::<R, u64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U32 => {
                let expected = into_data_sync::<R, u32>(self.clone());
                let actual = into_data_sync::<R, u32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U16 => {
                let expected = into_data_sync::<R, u16>(self.clone());
                let actual = into_data_sync::<R, u16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U8 => {
                let expected = into_data_sync::<R, u8>(self.clone());
                let actual = into_data_sync::<R, u8>(other);
                expected.assert_eq(&actual, true);
            }
            // Bool and quantized tensors are deliberately not checked.
            DType::Bool => (),
            DType::QFloat(..) => (),
        }
    }
}
118
119impl<R> core::fmt::Debug for CubeTensor<R>
120where
121 R: CubeRuntime,
122{
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 f.write_fmt(format_args!(
125 "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
126 self.shape,
127 self.device,
128 self.strides,
129 self.dtype.name(),
130 R::name(&self.client),
131 ))
132 }
133}
134
// Manual impl rather than `#[derive(Clone)]`: deriving would also require
// `R: Clone`, while only the fields actually need cloning.
impl<R> Clone for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            // NOTE(review): presumably `Handle::clone` shares the underlying
            // buffer rather than copying device memory — confirm; see also
            // `can_mut`, which depends on handle reference tracking.
            handle: self.handle.clone(),
            shape: self.shape.clone(),
            device: self.device.clone(),
            strides: self.strides.clone(),
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }
}
151
152impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
153 fn dtype(&self) -> DType {
154 self.dtype
155 }
156
157 fn shape(&self) -> Shape {
158 self.shape.clone()
159 }
160
161 fn rank(&self) -> usize {
162 self.shape.num_dims()
163 }
164}
165
166impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
167 fn scheme(&self) -> &burn_tensor::quantization::QuantScheme {
168 if let DType::QFloat(scheme) = &self.dtype {
169 scheme
170 } else {
171 panic!(
172 "Quantization scheme is not valid for dtype {:?}",
173 self.dtype,
174 )
175 }
176 }
177}
178
impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Creates a new tensor from its raw parts, taking the provided strides
    /// as-is (they are not validated against `shape`).
    pub fn new(
        client: ComputeClient<R::Server>,
        handle: Handle,
        shape: Shape,
        device: R::Device,
        strides: Vec<usize>,
        dtype: DType,
    ) -> Self {
        CubeTensor {
            client,
            handle,
            shape,
            device,
            strides,
            dtype,
            qparams: None,
        }
    }

    /// Creates a new tensor over `handle`, computing row-major (C-order)
    /// contiguous strides from `shape`.
    pub fn new_contiguous(
        client: ComputeClient<R::Server>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

        // Walk dimensions from innermost to outermost: the last dimension
        // gets stride 1, and each earlier dimension gets the product of the
        // sizes of all dimensions after it.
        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
            qparams: None,
        }
    }

    /// Copies this tensor's data to another compute client/device and returns
    /// the resulting tensor; `self` is left untouched.
    pub fn to_client(&self, client: ComputeClient<R::Server>, device: R::Device) -> Self {
        // Describe the source buffer (shape, strides, element size) and let
        // the runtime perform the cross-client transfer.
        let desc = self
            .handle
            .copy_descriptor(&self.shape.dims, &self.strides, self.elem_size());
        let alloc = self.client.to_client_tensor(desc, &client);

        Self {
            client,
            handle: alloc.handle,
            shape: self.shape.clone(),
            device,
            // Use the strides chosen by the destination allocation; they may
            // differ from the source layout.
            strides: alloc.strides,
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }

    /// Borrows this tensor as a CubeCL [`TensorHandleRef`].
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    /// Size in bytes of a single stored element.
    ///
    /// Quantized tensors are stored packed in `u32` words, so their stored
    /// element size is that of `u32` regardless of the quantized value width.
    pub fn elem_size(&self) -> usize {
        if let DType::QFloat(_) = self.dtype {
            core::mem::size_of::<u32>()
        } else {
            self.dtype.size()
        }
    }

    /// Builds a kernel tensor argument from this tensor with the given line
    /// (vectorization) size.
    pub fn as_tensor_arg<'a>(&'a self, line_size: u8) -> TensorArg<'a, R> {
        // NOTE(review): uses `self.dtype.size()` rather than `self.elem_size()`,
        // so quantized tensors pass the dtype's own size here instead of the
        // packed u32 size — confirm this is intended.
        let size = self.dtype.size();
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        // SAFETY: handle, strides and shape are all taken from this tensor,
        // so they consistently describe the same buffer.
        unsafe {
            TensorArg::from_raw_parts_and_size(
                handle.handle,
                handle.strides,
                handle.shape,
                line_size,
                size,
            )
        }
    }

    /// Builds a flat array kernel argument viewing the entire buffer as
    /// elements of type `E`, with the given vectorization factor.
    pub fn as_array_arg<E: CubeElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
        // SAFETY: the length is derived from the actual buffer size divided
        // by the element size, so the view cannot exceed the allocation.
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

    /// Returns whether this tensor's buffer can be reused in place as the
    /// output of a broadcasted binary operation with `rhs`.
    ///
    /// Requires exclusive mutable access to a buffer that exactly fits the
    /// tensor, and that no dimension of `self` is smaller than the matching
    /// dimension of `rhs` (i.e. broadcasting never needs to expand `self`).
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape[i];
            let shape_rhs = rhs.shape[i];

            // A smaller lhs dimension would have to be broadcast (expanded),
            // so its buffer cannot hold the result in place.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Returns a tensor containing a copy of this tensor's data.
    pub fn copy(&self) -> Self {
        // Identity unary op: routing it through the generic numeric unary
        // launch materializes the data into an output tensor.
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();
        // NOTE(review): presumably the launch allocates a fresh output because
        // the cloned handle is shared (not `can_mut`) — confirm in
        // `launch_unary_numeric`.
        launch_unary_numeric::<R, Copy, _>(tensor, |_| ())
    }

    /// Returns whether the underlying buffer can be mutated safely.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Panics if `other` does not live on the same device as `self`.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Returns whether the tensor's memory layout is row-major contiguous.
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Returns whether the underlying buffer holds exactly this tensor's
    /// elements — i.e. it is not an offset view, slice, or padded allocation.
    pub fn is_contiguous_buffer(&self) -> bool {
        // NOTE(review): compares against `dtype.size()`, not `elem_size()`,
        // so quantized (QFloat) tensors may not be handled correctly here —
        // confirm.
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}
374
#[cfg(test)]
mod tests {
    use super::*;

    // A unit-sized dimension contributes no step, so its stride value is
    // irrelevant: [3, 1] with strides [1, 1] is still contiguous.
    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    // Plain row-major 2D layout.
    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    // A transposed (column-major) layout is not contiguous.
    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    // Strides describing a slice of a larger buffer are not contiguous.
    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    // Standard row-major 4D layout.
    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    // Outer two dimensions are swapped relative to row-major order.
    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    // Even with all-unit leading dims, the innermost stride is 8 (not 1),
    // so consecutive elements are not adjacent in memory.
    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
}