1use crate::CubeRuntime;
2use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
3use burn_backend::quantization::QuantScheme;
4use burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata};
5use burn_std::{Metadata, strides, tensor::is_contiguous};
6use cubecl::server::Handle;
7use cubecl::std::tensor::TensorHandle;
8use cubecl::{client::ComputeClient, std::tensor::layout::linear::LinearViewLaunch};
9use cubecl::{frontend::Numeric, std::tensor::layout::linear::LinearViewLayoutLaunch};
10use cubecl::{
11 prelude::{TensorBinding, *},
12 std::tensor::layout::linear::LinearViewLayout,
13};
14use std::marker::PhantomData;
15
16use super::QParams;
17
/// A tensor resident on a cube (GPU compute) runtime `R`.
pub struct CubeTensor<R: CubeRuntime> {
    /// Client used to communicate with the compute server.
    pub client: ComputeClient<R>,
    /// Handle to the device buffer backing this tensor.
    pub handle: Handle,
    /// Shape and strides of the tensor (boxed metadata).
    pub meta: Box<Metadata>,
    /// Device this tensor lives on.
    pub device: R::Device,
    /// Element data type; `DType::QFloat(_)` marks a quantized tensor.
    pub dtype: DType,
    /// Quantization parameters, if any (see `try_scheme`).
    pub qparams: Option<QParams>,
}
33
34impl<R: CubeRuntime> From<CubeTensor<R>> for TensorHandle<R> {
35 fn from(val: CubeTensor<R>) -> Self {
36 TensorHandle::new(
37 val.handle.clone(),
38 val.meta.shape().clone(),
39 val.meta.strides().clone(),
40 val.dtype,
41 )
42 }
43}
44
/// Autotune output checking: with the `autotune-checks` feature enabled,
/// candidate kernel outputs are read back and compared approximately.
impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self) {
        use crate::ops::into_data_sync;
        use burn_backend::Tolerance;

        // Read both tensors back to host synchronously and compare with a
        // permissive f32 tolerance (candidates may round differently).
        let expected = into_data_sync::<R>(self.clone());
        let actual = into_data_sync::<R>(other);
        expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
    }
}
56
57impl<R> core::fmt::Debug for CubeTensor<R>
69where
70 R: CubeRuntime,
71{
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.write_fmt(format_args!(
74 "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
75 self.meta.shape(),
76 self.device,
77 self.meta.strides(),
78 self.dtype.name(),
79 R::name(&self.client),
80 ))
81 }
82}
83
84impl<R> Clone for CubeTensor<R>
85where
86 R: CubeRuntime,
87{
88 fn clone(&self) -> Self {
89 Self {
90 client: self.client.clone(),
91 handle: self.handle.clone(),
92 meta: self.meta.clone(),
93 device: self.device.clone(),
94 dtype: self.dtype,
95 qparams: self.qparams.clone(),
96 }
97 }
98}
99
100impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
101 fn dtype(&self) -> DType {
102 self.dtype
103 }
104
105 fn shape(&self) -> Shape {
106 self.meta.shape().clone()
107 }
108
109 fn rank(&self) -> usize {
110 self.meta.rank()
111 }
112}
113
114impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
115 fn scheme(&self) -> &QuantScheme {
116 if let DType::QFloat(scheme) = &self.dtype {
117 scheme
118 } else {
119 panic!(
120 "Quantization scheme is not valid for dtype {:?}",
121 self.dtype,
122 )
123 }
124 }
125}
126
127impl<R> CubeTensor<R>
128where
129 R: CubeRuntime,
130{
131 pub fn new(
133 client: ComputeClient<R>,
134 handle: Handle,
135 metadata: Metadata,
136 device: R::Device,
137 dtype: DType,
138 ) -> Self {
139 CubeTensor {
140 client,
141 handle,
142 meta: Box::new(metadata),
143 device,
144 dtype,
145 qparams: None,
146 }
147 }
148
149 pub fn new_contiguous(
151 client: ComputeClient<R>,
152 device: R::Device,
153 shape: Shape,
154 handle: Handle,
155 dtype: DType,
156 ) -> Self {
157 let ndims = shape.num_dims();
158 let mut strides = strides![0; ndims];
159 let mut current = 1;
160
161 shape.iter().enumerate().rev().for_each(|(index, val)| {
162 strides[index] = current;
163 current *= val;
164 });
165
166 Self {
167 client,
168 handle,
169 meta: Box::new(Metadata::new(shape, strides)),
170 device,
171 dtype,
172 qparams: None,
173 }
174 }
175
    /// Transfers this tensor's data to another compute client / device.
    ///
    /// Builds a copy descriptor over the full shape/strides at this dtype's
    /// element size, asks the source client to move the data to `client`,
    /// and returns a new tensor owned by the target client.
    ///
    /// NOTE(review): takes `&mut self` although no field is visibly mutated
    /// here — presumably `to_client_tensor` needs exclusive access to the
    /// source; confirm before relaxing to `&self`.
    pub fn to_client(&mut self, client: ComputeClient<R>, device: R::Device) -> Self {
        let desc = self.handle.clone().copy_descriptor(
            self.meta.shape().clone(),
            self.meta.strides().clone(),
            self.elem_size(),
        );
        let handle = self
            .client
            .to_client_tensor(desc, &client, self.dtype.into());

        // The result keeps this tensor's layout and quantization parameters.
        Self {
            client,
            handle,
            meta: Box::new(Metadata::new(self.shape(), self.meta.strides().clone())),
            device,
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }
196
197 pub fn binding(self) -> TensorBinding<R> {
199 TensorBinding {
200 handle: self.handle.binding(),
201 strides: self.meta.strides,
202 shape: self.meta.shape,
203 runtime: PhantomData,
204 }
205 }
206
    /// Size in bytes of a single element of this tensor's dtype.
    pub fn elem_size(&self) -> usize {
        self.dtype.size()
    }
211
    /// Consumes the tensor and converts it into a kernel tensor argument.
    pub fn into_tensor_arg(self) -> TensorArg<R> {
        self.binding().into_tensor_arg()
    }
216
    /// Consumes the tensor and converts it into a flat array argument.
    pub fn into_array_arg(self) -> ArrayArg<R> {
        self.into_tensor_arg().into_array_arg()
    }
221
222 pub fn as_tensor_alias(&self, input_pos: usize) -> TensorArg<R> {
224 TensorArg::Alias {
225 input_pos,
226 strides: self.meta.strides().clone(),
227 shape: self.meta.shape().clone(),
228 }
229 }
230
231 pub fn into_linear_view(self) -> LinearViewLaunch<R> {
233 let layout = LinearViewLayoutLaunch::new();
234 let buffer = self.into_tensor_arg();
235 LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)
236 }
237
238 pub fn as_linear_view_alias(&self, input_pos: usize) -> LinearViewLaunch<R> {
240 let layout = LinearViewLayoutLaunch::new();
241 let buffer = self.as_tensor_alias(input_pos);
242 LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)
243 }
244
245 pub fn into_linear_view_like(self, reference: &Self) -> LinearViewLaunch<R> {
247 let layout = LinearViewLayoutLaunch::from_reference_shape(reference.shape());
248 let buffer = self.into_tensor_arg();
249 LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)
250 }
251
252 pub fn required_address_type(&self) -> AddressType {
254 match self.try_scheme() {
255 Some(scheme) => {
256 let len = self.handle.size() as usize * 8 / scheme.size_bits_value();
257 AddressType::from_len(len)
258 }
259 None => AddressType::from_len(self.handle.size() as usize / self.dtype.size()),
260 }
261 }
262
263 pub fn try_scheme(&self) -> Option<&QuantScheme> {
265 match &self.dtype {
266 DType::QFloat(scheme) => Some(scheme),
267 _ => None,
268 }
269 }
270
271 pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
272 if !self.handle.can_mut() || !self.is_nonoverlapping() {
273 return false;
274 }
275 let ndims = self.meta.num_dims();
276
277 for i in 0..ndims {
278 let shape_lhs = self.meta.shape()[i];
279 let shape_rhs = rhs.meta.shape()[i];
280
281 if shape_lhs < shape_rhs {
283 return false;
284 }
285 }
286
287 true
288 }
289
    /// Copies the tensor by launching an identity unary kernel.
    pub fn copy(&self) -> Self {
        // Identity op: `execute` forwards the input vector unchanged; the
        // unary launch machinery produces the output tensor.
        struct Copy;

        #[cube]
        impl<T: Numeric, N: Size> NumericUnaryOp<T, N> for Copy {
            type Options = ();

            fn execute(input: Vector<T, N>, _options: &Self::Options) -> Vector<T, N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options = ();
            type Unary<T: Numeric, N: Size> = Self;
        }

        // NOTE(review): `self.clone()` clones handle/metadata, not buffer
        // contents — presumably the kernel launch materializes the new
        // buffer; confirm against `launch_unary_numeric`.
        let tensor = self.clone();
        launch_unary_numeric::<R, Copy, _>(tensor, |_| ())
    }
311
    /// Whether the underlying handle allows mutation (delegates to
    /// `Handle::can_mut`).
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }
316
317 pub fn assert_is_on_same_device(&self, other: &Self) {
319 if self.device != other.device {
320 panic!(
321 "Both tensors should be on the same device {:?} != {:?}",
322 self.device, other.device
323 );
324 }
325 }
326
    /// Whether shape/strides describe a dense row-major layout (delegates to
    /// `burn_std::tensor::is_contiguous`).
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(self.meta.shape(), self.meta.strides())
    }
336
    /// Whether the backing buffer is exactly the size of the logical tensor,
    /// i.e. element count × element byte size equals the handle's byte size.
    pub fn is_contiguous_buffer(&self) -> bool {
        self.meta.shape().num_elements() * self.dtype.size() == self.handle.size() as usize
    }
342
    /// Whether no two distinct index tuples map to the same memory offset.
    ///
    /// A zero stride always aliases (every index along that dim hits the
    /// same offset). Otherwise dimensions are ordered by ascending stride
    /// and scanned: each dimension of extent > 1 must have a stride strictly
    /// greater than the maximum offset reachable by all smaller-stride
    /// dimensions combined, or its elements interleave with them.
    pub fn is_nonoverlapping(&self) -> bool {
        let shape = self.meta.shape();
        let strides = self.meta.strides();

        // Any broadcast (stride-0) dimension trivially overlaps itself.
        if strides.contains(&0) {
            return false;
        }
        let rank = self.rank();
        if rank > 1 {
            // Pair each extent with its stride, ordered by ascending stride.
            let mut dims = shape.iter().zip(strides.iter()).collect::<Vec<_>>();
            dims.sort_by_key(|(_, stride)| **stride);

            // `max_offset` is the largest offset addressable using the dims
            // processed so far; extent-1 dims are exempt (single position).
            let mut max_offset = 0;
            for (shape, stride) in dims.into_iter() {
                if *stride <= max_offset && *shape != 1 {
                    return false;
                }

                max_offset += (*shape - 1) * *stride;
            }
        }
        true
    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    // Extent-1 dims: their stride value is irrelevant for contiguity.
    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    // Plain row-major 2D layout.
    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    // Transposed (column-major) strides are not contiguous.
    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    // Stride order does not match dimension order.
    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    // Dense 4D row-major layout.
    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    // Two outer dimensions swapped in stride order.
    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    // Innermost stride of 8 leaves gaps even though outer dims are all 1.
    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
}