1use core::marker::PhantomData;
2use cubecl_core::ir::StorageType;
3use cubecl_core::tensor_line_size_parallel;
4use cubecl_core::{Runtime, server};
5use cubecl_core::{calculate_cube_count_elemwise, server::Allocation};
6use cubecl_core::{prelude::*, server::CopyDescriptor};
7use cubecl_runtime::server::Handle;
8
9pub struct TensorHandle<R>
11where
12 R: Runtime,
13{
14 pub handle: server::Handle,
16 pub shape: Vec<usize>,
17 pub strides: Vec<usize>,
18 pub dtype: StorageType,
20 runtime: PhantomData<R>,
21}
22
23impl<R> core::fmt::Debug for TensorHandle<R>
24where
25 R: Runtime,
26{
27 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28 f.write_fmt(format_args!(
29 "Tensor {{ shape: {:?}, strides: {:?}, dtype: {}}}",
30 self.shape, self.strides, self.dtype,
31 ))
32 }
33}
34
35impl<R> Clone for TensorHandle<R>
36where
37 R: Runtime,
38{
39 fn clone(&self) -> Self {
40 Self {
41 handle: self.handle.clone(),
42 shape: self.shape.clone(),
43 strides: self.strides.clone(),
44 dtype: self.dtype,
45 runtime: PhantomData,
46 }
47 }
48}
49
50impl<R> TensorHandle<R>
51where
52 R: Runtime,
53{
54 pub fn new(
56 handle: server::Handle,
57 shape: Vec<usize>,
58 strides: Vec<usize>,
59 storage: StorageType,
60 ) -> Self {
61 Self {
62 handle,
63 shape,
64 strides,
65 dtype: storage,
66 runtime: PhantomData,
67 }
68 }
69
70 pub fn empty(
71 client: &ComputeClient<R::Server>,
72 shape: Vec<usize>,
73 storage: StorageType,
74 ) -> Self {
75 let elem_size = storage.size();
76 let Allocation { handle, strides } = client.empty_tensor(&shape, elem_size);
77
78 Self::new(handle, shape, strides, storage)
79 }
80
81 pub fn from_ref(handle: &TensorHandleRef<'_, R>, storage: StorageType) -> Self {
83 Self {
84 handle: handle.handle.clone(),
85 shape: handle.shape.to_vec(),
86 strides: handle.strides.to_vec(),
87 dtype: storage,
88 runtime: PhantomData,
89 }
90 }
91
92 pub fn new_contiguous(shape: Vec<usize>, handle: Handle, storage: StorageType) -> Self {
94 let strides = Self::contiguous_strides(&shape);
95
96 Self {
97 handle,
98 shape,
99 strides,
100 dtype: storage,
101 runtime: PhantomData,
102 }
103 }
104
105 pub fn can_mut(&self) -> bool {
107 self.handle.can_mut()
108 }
109
110 pub fn as_ref(&self) -> TensorHandleRef<'_, R> {
111 unsafe {
112 TensorHandleRef::from_raw_parts(
113 &self.handle,
114 &self.strides,
115 &self.shape,
116 self.dtype.size(),
117 )
118 }
119 }
120
121 pub fn as_arg<'a>(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
123 let handle: TensorHandleRef<'a, R> = self.as_ref();
124
125 unsafe {
126 TensorArg::from_raw_parts_and_size(
127 handle.handle,
128 handle.strides,
129 handle.shape,
130 vectorisation,
131 handle.elem_size,
132 )
133 }
134 }
135
136 pub fn as_copy_descriptor<'a>(&'a self) -> CopyDescriptor<'a> {
137 CopyDescriptor {
138 binding: self.handle.clone().binding(),
139 shape: &self.shape,
140 strides: &self.strides,
141 elem_size: self.dtype.size(),
142 }
143 }
144
145 fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
146 let mut strides = Vec::with_capacity(shape.len());
147
148 let mut current = 1;
149 shape.iter().enumerate().rev().for_each(|(_, val)| {
150 strides.push(current);
151 current *= val;
152 });
153 strides.reverse();
154 strides
155 }
156}
157impl<R> TensorHandle<R>
158where
159 R: Runtime,
160{
161 pub fn zeros(client: &ComputeClient<R::Server>, shape: Vec<usize>, dtype: StorageType) -> Self {
162 let num_elements: usize = shape.iter().product();
163 let rank = shape.len();
164 let output = Self::empty(client, shape, dtype);
165
166 let line_size = tensor_line_size_parallel(
167 R::supported_line_sizes().iter().cloned(),
168 &output.shape,
169 &output.strides,
170 rank - 1,
171 );
172
173 let cube_dim = CubeDim::default();
174 let cube_count = calculate_cube_count_elemwise(num_elements / line_size as usize, cube_dim);
175 let array_len = output.handle.size() as usize / dtype.size();
176
177 unsafe {
178 init::zeros_array::launch_unchecked::<R>(
179 client,
180 cube_count,
181 cube_dim,
182 ArrayArg::from_raw_parts_and_size(
183 &output.handle,
184 array_len,
185 line_size,
186 dtype.size(),
187 ),
188 dtype,
189 )
190 };
191
192 output
193 }
194}
195
/// Kernels used to initialise tensor buffers.
pub(crate) mod init {
    use cubecl::prelude::*;
    use cubecl_core::{self as cubecl, ir::StorageType};

    /// Writes a zero line into `output` at this invocation's `ABSOLUTE_POS`.
    ///
    /// Invocations whose position is not below `output.len()` do nothing, so
    /// launching more invocations than there are lines is harmless.
    ///
    /// NOTE(review): `_elem` is unused in the body; `#[define(C)]` presumably
    /// binds the generic `C` to the storage type chosen at launch — confirm
    /// against the CubeCL macro documentation.
    #[cube(launch_unchecked)]
    pub fn zeros_array<C: Numeric>(output: &mut Array<Line<C>>, #[define(C)] _elem: StorageType) {
        // Bounds guard: skip positions past the end of the array.
        if ABSOLUTE_POS < output.len() {
            // Convert the scalar zero of `C` into a `Line<C>` and store it.
            output[ABSOLUTE_POS] = Line::cast_from(C::from_int(0));
        }
    }
}