1use crate::CubeRuntime;
2use crate::element::CubeElement;
3use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
4use burn_std::tensor::is_contiguous;
5use burn_tensor::quantization::QTensorPrimitive;
6use burn_tensor::{DType, Shape, TensorMetadata};
7use cubecl::client::ComputeClient;
8use cubecl::frontend::Numeric;
9use cubecl::prelude::{TensorHandleRef, *};
10use cubecl::server::Handle;
11use cubecl::std::tensor::TensorHandle;
12use std::marker::PhantomData;
13
14use super::QParams;
15
/// Tensor primitive backed by a CubeCL runtime buffer.
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client used to submit work to this tensor's runtime.
    pub client: ComputeClient<R>,
    /// Handle to the underlying device buffer.
    pub handle: Handle,
    /// Logical shape of the tensor.
    pub shape: Shape,
    /// Device on which the tensor's buffer lives.
    pub device: R::Device,
    /// Per-dimension strides, in number of elements (see `new_contiguous`).
    pub strides: Vec<usize>,
    /// Element data type of the tensor.
    pub dtype: DType,
    /// Quantization parameters; `None` for non-quantized tensors.
    pub qparams: Option<QParams>,
}
33
34impl<R: CubeRuntime> From<CubeTensor<R>> for TensorHandle<R> {
35 fn from(val: CubeTensor<R>) -> Self {
36 TensorHandle::new(
37 val.handle,
38 val.shape.to_vec(),
39 val.strides.to_vec(),
40 val.dtype.into(),
41 )
42 }
43}
44
impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    // Verifies that two autotune candidate kernels produced equivalent outputs by
    // reading both tensors back to the host and comparing them. Only compiled when
    // the `autotune-checks` feature is enabled, since the synchronous read-back is
    // expensive. Panics (via the assert helpers) on mismatch.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self) {
        use burn_tensor::Tolerance;

        use crate::ops::into_data_sync;

        match self.dtype {
            // Float outputs are compared with a permissive tolerance: different
            // candidate kernels may legitimately round or accumulate differently.
            DType::F64 => {
                let expected = into_data_sync::<R, f64>(self.clone());
                let actual = into_data_sync::<R, f64>(other);
                expected.assert_approx_eq::<f64>(&actual, Tolerance::permissive());
            }
            DType::F32 | DType::Flex32 => {
                let expected = into_data_sync::<R, f32>(self.clone());
                let actual = into_data_sync::<R, f32>(other);
                expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
            }
            DType::F16 => {
                let expected = into_data_sync::<R, half::f16>(self.clone());
                let actual = into_data_sync::<R, half::f16>(other);
                expected.assert_approx_eq::<half::f16>(&actual, Tolerance::permissive());
            }
            DType::BF16 => {
                let expected = into_data_sync::<R, half::bf16>(self.clone());
                let actual = into_data_sync::<R, half::bf16>(other);
                expected.assert_approx_eq::<half::bf16>(&actual, Tolerance::permissive());
            }
            // Integer outputs must match exactly.
            DType::I64 => {
                let expected = into_data_sync::<R, i64>(self.clone());
                let actual = into_data_sync::<R, i64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I32 => {
                let expected = into_data_sync::<R, i32>(self.clone());
                let actual = into_data_sync::<R, i32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I16 => {
                let expected = into_data_sync::<R, i16>(self.clone());
                let actual = into_data_sync::<R, i16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::I8 => {
                let expected = into_data_sync::<R, i8>(self.clone());
                let actual = into_data_sync::<R, i8>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U64 => {
                let expected = into_data_sync::<R, u64>(self.clone());
                let actual = into_data_sync::<R, u64>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U32 => {
                let expected = into_data_sync::<R, u32>(self.clone());
                let actual = into_data_sync::<R, u32>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U16 => {
                let expected = into_data_sync::<R, u16>(self.clone());
                let actual = into_data_sync::<R, u16>(other);
                expected.assert_eq(&actual, true);
            }
            DType::U8 => {
                let expected = into_data_sync::<R, u8>(self.clone());
                let actual = into_data_sync::<R, u8>(other);
                expected.assert_eq(&actual, true);
            }
            // Bool and quantized outputs are deliberately not checked here.
            DType::Bool => (),
            DType::QFloat(..) => (),
        }
    }
}
118
119impl<R> core::fmt::Debug for CubeTensor<R>
120where
121 R: CubeRuntime,
122{
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 f.write_fmt(format_args!(
125 "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
126 self.shape,
127 self.device,
128 self.strides,
129 self.dtype.name(),
130 R::name(&self.client),
131 ))
132 }
133}
134
// Manual `Clone` impl: a `#[derive(Clone)]` would add an implicit `R: Clone`
// bound, which `CubeRuntime` types are not required to satisfy.
impl<R> Clone for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            // NOTE(review): cloning `Handle` presumably shares the underlying
            // buffer rather than copying device memory (consistent with the
            // `can_mut` checks elsewhere in this file) — confirm against
            // cubecl's `Handle` semantics.
            handle: self.handle.clone(),
            shape: self.shape.clone(),
            device: self.device.clone(),
            strides: self.strides.clone(),
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }
}
151
152impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
153 fn dtype(&self) -> DType {
154 self.dtype
155 }
156
157 fn shape(&self) -> Shape {
158 self.shape.clone()
159 }
160
161 fn rank(&self) -> usize {
162 self.shape.num_dims()
163 }
164}
165
166impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
167 fn scheme(&self) -> &burn_tensor::quantization::QuantScheme {
168 if let DType::QFloat(scheme) = &self.dtype {
169 scheme
170 } else {
171 panic!(
172 "Quantization scheme is not valid for dtype {:?}",
173 self.dtype,
174 )
175 }
176 }
177}
178
impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Creates a new tensor from its raw parts, with no quantization parameters.
    ///
    /// The caller is responsible for `strides` being consistent with `shape`
    /// and with the buffer behind `handle`.
    pub fn new(
        client: ComputeClient<R>,
        handle: Handle,
        shape: Shape,
        device: R::Device,
        strides: Vec<usize>,
        dtype: DType,
    ) -> Self {
        CubeTensor {
            client,
            handle,
            shape,
            device,
            strides,
            dtype,
            qparams: None,
        }
    }

    /// Creates a new tensor, computing row-major (contiguous) strides from `shape`.
    pub fn new_contiguous(
        client: ComputeClient<R>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];
        let mut current = 1;

        // Walk dimensions from innermost (last) to outermost: each stride is
        // the product of the sizes of all faster-varying dimensions, so the
        // innermost stride is 1 (strides are in elements, not bytes).
        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
            qparams: None,
        }
    }

    /// Copies this tensor's data to another compute client, returning a new
    /// tensor bound to that client and `device`.
    ///
    /// The returned tensor uses the strides reported by the destination
    /// allocation, which need not equal `self.strides`.
    pub fn to_client(&self, client: ComputeClient<R>, device: R::Device) -> Self {
        let desc = self
            .handle
            .copy_descriptor(&self.shape.dims, &self.strides, self.elem_size());
        let alloc = self.client.to_client_tensor(desc, &client);

        Self {
            client,
            handle: alloc.handle,
            shape: self.shape.clone(),
            device,
            strides: alloc.strides,
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }

    /// Borrows this tensor as a CubeCL tensor handle reference (no copy).
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    /// Size of one element in bytes.
    pub fn elem_size(&self) -> usize {
        self.dtype.size()
    }

    /// Borrows this tensor as a kernel launch argument with the given line
    /// (vectorization) size.
    pub fn as_tensor_arg<'a>(&'a self, line_size: u8) -> TensorArg<'a, R> {
        let size = self.dtype.size();
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        // SAFETY: handle, strides, shape and element size all come from this
        // tensor's own metadata and describe the same allocation.
        unsafe {
            TensorArg::from_raw_parts_and_size(
                handle.handle,
                handle.strides,
                handle.shape,
                line_size,
                size,
            )
        }
    }

    /// Borrows this tensor as a flat array launch argument of element type `E`.
    ///
    /// The length is derived from the buffer's byte size, so `E` is assumed to
    /// match this tensor's `dtype` size — TODO confirm callers uphold this.
    pub fn as_array_arg<E: CubeElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
        // SAFETY: the length is computed from the handle's own byte size
        // divided by the size of `E`, so it cannot exceed the allocation.
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

    /// Whether this tensor's buffer can be reused in place as the output of a
    /// broadcasted binary op with `rhs`: the buffer must be mutable and densely
    /// packed, and no dimension of `self` may be smaller than the matching
    /// dimension of `rhs` (otherwise the output would not fit).
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape[i];
            let shape_rhs = rhs.shape[i];

            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Returns a copy of this tensor with its own buffer, produced by launching
    /// an identity (pass-through) unary kernel.
    pub fn copy(&self) -> Self {
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            // Identity: output each line unchanged.
            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();
        launch_unary_numeric::<R, Copy, _>(tensor, |_| ())
    }

    /// Whether this tensor's buffer may be mutated in place.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Panics if `other` is on a different device than `self`.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Whether the tensor's strides describe a contiguous row-major layout.
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Whether the buffer holds exactly this tensor's elements (no padding and
    /// no view into a larger allocation): element count times element size
    /// equals the buffer's byte size.
    pub fn is_contiguous_buffer(&self) -> bool {
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}
369
#[cfg(test)]
mod tests {
    use super::*;

    // A size-1 dimension makes any stride value for it irrelevant, so strides
    // need not strictly decrease for the layout to be contiguous.
    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    // Plain row-major 2D layout.
    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    // Transposed (column-major) strides are not contiguous.
    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    // A slice whose strides skip elements of a larger buffer is not contiguous.
    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    // Full row-major 4D layout.
    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    // Two permuted outer dimensions break contiguity even in 4D.
    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    // Even with all-unit leading dims, a non-unit innermost stride (8) means
    // elements are not densely packed.
    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
}