1use crate::element::JitElement;
2use crate::kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily};
3use crate::JitRuntime;
4use burn_tensor::quantization::QTensorPrimitive;
5use burn_tensor::{DType, Shape, TensorMetadata};
6use cubecl::client::ComputeClient;
7use cubecl::frontend::Numeric;
8use cubecl::linalg::tensor::TensorHandle;
9use cubecl::prelude::{TensorHandleRef, *};
10use cubecl::server::Handle;
11use std::marker::PhantomData;
12
/// The basic tensor primitive struct for JIT backends.
#[derive(new)]
pub struct JitTensor<R: JitRuntime> {
    /// Compute client for the tensor's runtime/server.
    pub client: ComputeClient<R::Server, R::Channel>,
    /// Handle to the memory buffer holding the tensor data.
    pub handle: Handle,
    /// The shape of the tensor.
    pub shape: Shape,
    /// The device where the tensor lives.
    pub device: R::Device,
    /// The strides of the tensor, in elements, one per dimension.
    pub strides: Vec<usize>,
    /// Element data type; `DType::QFloat(..)` marks a quantized tensor.
    pub(crate) dtype: DType,
}
28
29impl<R: JitRuntime, E: JitElement> From<JitTensor<R>> for TensorHandle<R, E> {
30 fn from(val: JitTensor<R>) -> Self {
31 TensorHandle::new(val.shape.dims.to_vec(), val.strides.to_vec(), val.handle)
32 }
33}
34
35impl<R> core::fmt::Debug for JitTensor<R>
36where
37 R: JitRuntime,
38{
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.write_fmt(format_args!(
41 "JitTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
42 self.shape,
43 self.device,
44 self.strides,
45 self.dtype.name(),
46 R::name(),
47 ))
48 }
49}
50
51impl<R> Clone for JitTensor<R>
52where
53 R: JitRuntime,
54{
55 fn clone(&self) -> Self {
56 Self {
57 client: self.client.clone(),
58 handle: self.handle.clone(),
59 shape: self.shape.clone(),
60 device: self.device.clone(),
61 strides: self.strides.clone(),
62 dtype: self.dtype,
63 }
64 }
65}
66
impl<R: JitRuntime> TensorMetadata for JitTensor<R> {
    /// Element data type of the tensor.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// Shape of the tensor (cloned, since `Shape` is returned by value).
    fn shape(&self) -> Shape {
        self.shape.clone()
    }
}
76
77impl<R: JitRuntime> QTensorPrimitive for JitTensor<R> {
78 fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme {
79 if let DType::QFloat(scheme) = &self.dtype {
80 scheme
81 } else {
82 panic!(
83 "Quantization scheme is not valid for dtype {:?}",
84 self.dtype,
85 )
86 }
87 }
88}
89
/// Dispatch `$op` with the type alias `$element` bound to the concrete Rust
/// type that corresponds to the runtime `DType` value.
///
/// Arms:
/// - `float($dtype), $element, $op`: float dtypes only (F64, F32, F16, BF16).
/// - `float($lhs_dtype, $rhs_dtype), $element, $op`: float dtypes; panics
///   first when the two dtypes differ (binary-op precondition).
/// - `$dtype, $element, $op`: all float, unsigned and signed integer dtypes.
///
/// Any dtype not listed (e.g. `Bool`, `QFloat`) hits `unimplemented!`.
#[macro_export]
macro_rules! execute_with_dtype {
    // Float-only dispatch.
    (float($dtype:expr), $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            _ => unimplemented!("Unsupported dtype"),
        }
    }};

    // Binary float dispatch: both operands must share the same dtype.
    (float($lhs_dtype:expr, $rhs_dtype:expr), $element:ident, $op:expr) => {{
        if $lhs_dtype != $rhs_dtype {
            panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                $lhs_dtype, $rhs_dtype
            );
        }
        execute_with_dtype!(float($lhs_dtype), $element, $op)
    }};

    // Numeric dispatch: floats plus all integer widths.
    ($dtype:expr, $element:ident, $op:expr) => {{
        match $dtype {
            burn_tensor::DType::F64 => {
                type $element = f64;
                $op
            }
            burn_tensor::DType::F32 => {
                type $element = f32;
                $op
            }
            burn_tensor::DType::F16 => {
                type $element = half::f16;
                $op
            }
            burn_tensor::DType::BF16 => {
                type $element = half::bf16;
                $op
            }
            burn_tensor::DType::U64 => {
                type $element = u64;
                $op
            }
            burn_tensor::DType::U32 => {
                type $element = u32;
                $op
            }
            burn_tensor::DType::U16 => {
                type $element = u16;
                $op
            }
            burn_tensor::DType::U8 => {
                type $element = u8;
                $op
            }
            burn_tensor::DType::I64 => {
                type $element = i64;
                $op
            }
            burn_tensor::DType::I32 => {
                type $element = i32;
                $op
            }
            burn_tensor::DType::I16 => {
                type $element = i16;
                $op
            }
            burn_tensor::DType::I8 => {
                type $element = i8;
                $op
            }
            _ => unimplemented!("Unsupported dtype"),
        }
    }};
}
192
impl<R> JitTensor<R>
where
    R: JitRuntime,
{
    /// Create a new tensor with a contiguous (row-major) memory layout.
    ///
    /// # Arguments
    ///
    /// - `client`: The compute client to use.
    /// - `device`: The device where the tensor lives.
    /// - `shape`: The shape of the tensor.
    /// - `handle`: The handle to the tensor's memory.
    /// - `dtype`: The element data type.
    pub fn new_contiguous(
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = vec![0; ndims];

        // Row-major strides: the innermost dim gets stride 1; each outer
        // dim's stride is the running product of the dims after it.
        let mut current = 1;
        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            client,
            handle,
            shape,
            strides,
            device,
            dtype,
        }
    }

    /// Move the tensor to the given client and device by reading its bytes
    /// back synchronously and re-uploading them to the new client.
    ///
    /// # Panics
    ///
    /// Panics if the read cannot be completed synchronously.
    pub fn to_client(
        &self,
        client: ComputeClient<R::Server, R::Channel>,
        device: R::Device,
    ) -> Self {
        let bytes = burn_common::reader::try_read_sync(
            self.client.read_one_async(self.handle.clone().binding()),
        )
        .expect("Can only change client synchronously");
        let handle = client.create(&bytes);

        Self {
            client,
            handle,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            device,
            dtype: self.dtype,
        }
    }

    /// Borrow the tensor as a cubecl `TensorHandleRef` (handle, strides,
    /// shape and element size) for kernel launches.
    pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape.dims,
            runtime: PhantomData,
            elem_size: self.elem_size(),
        }
    }

    // Size in bytes of one stored element.
    // NOTE(review): for `QFloat` the stored unit is a `u32` word — presumably
    // quantized values are packed into u32 words; confirm against the
    // quantization kernels.
    fn elem_size(&self) -> usize {
        if let DType::QFloat(_) = self.dtype {
            core::mem::size_of::<u32>()
        } else {
            self.dtype.size()
        }
    }

    /// Return the tensor as a cubecl `TensorArg` with the given
    /// vectorization (line size) factor.
    ///
    /// `E` must match the tensor's actual element type.
    pub fn as_tensor_arg<'a, E: JitElement>(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
        let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

        // SAFETY: handle/strides/shape are borrowed from `self` for 'a and
        // describe this tensor's actual buffer layout.
        unsafe {
            TensorArg::from_raw_parts::<E>(
                handle.handle,
                handle.strides,
                handle.shape,
                vectorisation,
            )
        }
    }

    /// Return the tensor as a flat cubecl `ArrayArg`; the element count is
    /// derived from the buffer size divided by `size_of::<E>()`.
    pub fn as_array_arg<E: JitElement>(&self, vectorisation: u8) -> ArrayArg<'_, R> {
        // SAFETY: the handle is borrowed from `self`; `E` must match the
        // tensor's element type for the length computation to be valid.
        unsafe {
            ArrayArg::from_raw_parts::<E>(
                &self.handle,
                self.handle.size() as usize / core::mem::size_of::<E>(),
                vectorisation,
            )
        }
    }

    // Whether this tensor's buffer can be reused in place as the output of a
    // broadcasted binary op with `rhs`: requires a uniquely-owned handle, a
    // full-size contiguous buffer, and no lhs dim smaller than the matching
    // rhs dim (i.e. lhs is never the broadcasted side).
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_contiguous_buffer() {
            return false;
        }
        let ndims = self.shape.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.shape.dims[i];
            let shape_rhs = rhs.shape.dims[i];

            // The output buffer would be too small to hold the rhs extent.
            if shape_lhs < shape_rhs {
                return false;
            }
        }

        true
    }

    /// Copy the current tensor into a freshly allocated buffer by launching
    /// an identity (pass-through) unary kernel.
    pub fn copy(&self) -> Self {
        // Identity op: the kernel forwards its input line unchanged.
        struct Copy;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Copy {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options<N: Numeric> = ();
            type Unary<N: Numeric> = Self;
        }

        let tensor = self.clone();

        // Dispatch on the runtime dtype to select the concrete element type.
        execute_with_dtype!(
            tensor.dtype,
            E,
            launch_unary_numeric::<R, E, Copy, _>(tensor, |_| ())
        )
    }

    /// Check if the tensor's buffer can be safely mutated (uniquely owned).
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Assert that both tensors are on the same device.
    ///
    /// # Panics
    ///
    /// Panics when the devices differ.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Check if the tensor's memory layout is contiguous (row-major).
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.shape.dims, &self.strides)
    }

    /// Check that the buffer size exactly matches
    /// `num_elements * dtype.size()`, i.e. the tensor is not a view into a
    /// larger allocation.
    /// NOTE(review): this uses `dtype.size()`, not `elem_size()`, so the
    /// packed `QFloat` representation is not accounted for here — confirm.
    pub fn is_contiguous_buffer(&self) -> bool {
        self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize
    }
}
368
/// Check whether `shape`/`strides` describe a contiguous row-major layout,
/// i.e. iterating the tensor in row-major index order visits the buffer's
/// elements densely and in order.
///
/// Dimensions of size 1 are skipped: their stride never affects addressing,
/// so any stride value is acceptable there. This fixes two defects of the
/// previous version:
/// - it accepted layouts whose innermost stride was not 1
///   (e.g. `shape [32, 32]`, `strides [32, 2]`);
/// - it rejected genuinely contiguous layouts containing size-1 dims
///   (e.g. `shape [4, 1]`, `strides [1, 1]` — exactly what
///   `new_contiguous` produces for that shape).
pub(crate) fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    // Walk dims from innermost to outermost. Every non-unit dim must have a
    // stride equal to the number of elements covered by the dims inside it
    // (1 for the innermost non-unit dim).
    let mut expected_stride = 1;

    for (stride, dim) in strides.iter().zip(shape).rev() {
        // A size-1 dim is never actually iterated; its stride is irrelevant.
        if *dim == 1 {
            continue;
        }
        if *stride != expected_stride {
            return false;
        }
        expected_stride *= dim;
    }

    true
}
398
#[cfg(test)]
mod tests {
    use crate::tensor::base::is_contiguous;

    // Plain row-major 2D layout is contiguous.
    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    // Transposed strides: the innermost dim no longer has stride 1.
    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    // Strides of a slice of a larger tensor: gaps between rows.
    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    // 4D row-major layout: each stride is the product of the inner dims.
    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    // Same strides with the two outer dims swapped: ordering is broken.
    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }
}