1use core::marker::PhantomData;
2use cubecl_core::{Runtime, server, zspace::strides};
3use cubecl_core::{calculate_cube_count_elemwise, server::Allocation};
4use cubecl_core::{ir::StorageType, zspace::metadata::Metadata};
5use cubecl_core::{prelude::*, server::CopyDescriptor};
6use cubecl_core::{
7 tensor_line_size_parallel,
8 zspace::{Shape, Strides},
9};
10use cubecl_runtime::server::Handle;
11
/// A tensor allocation on a runtime `R`: a raw server buffer handle plus the
/// shape/strides metadata and element storage type needed to interpret it.
pub struct TensorHandle<R>
where
    R: Runtime,
{
    /// Raw server-side buffer backing the tensor data.
    pub handle: server::Handle,
    /// Shape and strides describing the buffer's logical layout (boxed to keep
    /// the handle itself small).
    pub metadata: Box<Metadata>,
    /// Element storage type; determines the per-element byte size.
    pub dtype: StorageType,
    // Ties this handle to a specific runtime `R` without storing a value of it.
    runtime: PhantomData<R>,
}
24
25impl<R> core::fmt::Debug for TensorHandle<R>
26where
27 R: Runtime,
28{
29 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
30 f.write_fmt(format_args!(
31 "Tensor {{ shape: {:?}, strides: {:?}, dtype: {}}}",
32 self.shape(),
33 self.strides(),
34 self.dtype,
35 ))
36 }
37}
38
39impl<R> Clone for TensorHandle<R>
40where
41 R: Runtime,
42{
43 fn clone(&self) -> Self {
44 Self {
45 handle: self.handle.clone(),
46 metadata: self.metadata.clone(),
47 dtype: self.dtype,
48 runtime: PhantomData,
49 }
50 }
51}
52
53impl<R> TensorHandle<R>
54where
55 R: Runtime,
56{
57 pub fn new(
59 handle: server::Handle,
60 shape: impl Into<Shape>,
61 strides: impl Into<Strides>,
62 storage: StorageType,
63 ) -> Self {
64 Self {
65 handle,
66 metadata: Box::new(Metadata::new(shape, strides)),
67 dtype: storage,
68 runtime: PhantomData,
69 }
70 }
71
72 pub fn empty(client: &ComputeClient<R>, shape: impl Into<Shape>, storage: StorageType) -> Self {
73 let shape = shape.into();
74 let elem_size = storage.size();
75 let Allocation { handle, strides } = client.empty_tensor(&shape, elem_size);
76
77 Self::new(handle, shape, strides, storage)
78 }
79
80 pub fn from_ref(handle: &TensorHandleRef<'_, R>, storage: StorageType) -> Self {
82 Self {
83 handle: handle.handle.clone(),
84 metadata: Box::new(Metadata::new(handle.shape, handle.strides)),
85 dtype: storage,
86 runtime: PhantomData,
87 }
88 }
89
90 pub fn new_contiguous(shape: impl Into<Shape>, handle: Handle, storage: StorageType) -> Self {
92 let shape = shape.into();
93 let strides = Self::contiguous_strides(&shape);
94
95 Self {
96 handle,
97 metadata: Box::new(Metadata::new(shape, strides)),
98 dtype: storage,
99 runtime: PhantomData,
100 }
101 }
102
103 pub fn can_mut(&self) -> bool {
105 self.handle.can_mut()
106 }
107
108 pub fn as_ref(&self) -> TensorHandleRef<'_, R> {
109 unsafe {
110 TensorHandleRef::from_raw_parts(
111 &self.handle,
112 self.strides(),
113 self.shape(),
114 self.dtype.size(),
115 )
116 }
117 }
118
119 pub fn as_arg<'a>(&'a self, line_size: LineSize) -> TensorArg<'a, R> {
121 let handle: TensorHandleRef<'a, R> = self.as_ref();
122
123 unsafe {
124 TensorArg::from_raw_parts_and_size(
125 handle.handle,
126 handle.strides,
127 handle.shape,
128 line_size,
129 handle.elem_size,
130 )
131 }
132 }
133
134 pub fn as_copy_descriptor<'a>(&'a self) -> CopyDescriptor<'a> {
135 CopyDescriptor {
136 binding: self.handle.clone().binding(),
137 shape: self.shape(),
138 strides: self.strides(),
139 elem_size: self.dtype.size(),
140 }
141 }
142
143 pub fn required_address_type(&self) -> AddressType {
144 let len = self.handle.size() / self.dtype.size() as u64;
145 AddressType::from_len(len as usize)
146 }
147
148 pub fn shape(&self) -> &Shape {
149 self.metadata.shape()
150 }
151
152 pub fn strides(&self) -> &Strides {
153 self.metadata.strides()
154 }
155
156 fn contiguous_strides(shape: &[usize]) -> Strides {
157 let mut strides = strides![1; shape.len()];
158
159 let mut current = 1;
160 shape.iter().rev().enumerate().for_each(|(i, val)| {
161 strides[i] = current;
162 current *= val;
163 });
164 strides.reverse();
165 strides
166 }
167}
impl<R> TensorHandle<R>
where
    R: Runtime,
{
    /// Allocates a tensor of `shape` on `client` and fills it with zeros by
    /// launching the `init::zeros_array` kernel.
    pub fn zeros(client: &ComputeClient<R>, shape: impl Into<Shape>, dtype: StorageType) -> Self {
        let shape = shape.into();
        let num_elements: usize = shape.iter().product();
        let rank = shape.len();
        // Fresh (uninitialized) allocation that the kernel below zeroes out.
        let output = Self::empty(client, shape, dtype);

        // Pick the widest vectorization (line size) compatible with the
        // output layout along the innermost axis (`rank - 1`).
        // NOTE(review): `rank - 1` underflows for a rank-0 (scalar) shape —
        // confirm callers never pass an empty shape here.
        let line_size = tensor_line_size_parallel(
            client.io_optimized_line_sizes(dtype.size()),
            output.shape(),
            output.strides(),
            rank - 1,
        );

        // One working unit per line of `line_size` elements.
        let working_units = num_elements / line_size as usize;
        let cube_dim = CubeDim::new(client, working_units);
        let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
        // Buffer length in elements; may exceed `num_elements` if the
        // allocation is padded, so the kernel bounds-checks against it.
        let array_len = output.handle.size() as usize / dtype.size();

        // SAFETY(review): unchecked launch — relies on the cube count/dim and
        // array argument above matching the kernel's expectations.
        unsafe {
            init::zeros_array::launch_unchecked(
                client,
                cube_count,
                cube_dim,
                output.required_address_type(),
                ArrayArg::from_raw_parts_and_size(
                    &output.handle,
                    array_len,
                    line_size,
                    dtype.size(),
                ),
                dtype,
            )
            .expect("Should be able to launch the kernel all the time")
        };

        output
    }
}
210
/// Kernels used to initialize tensor contents.
pub(crate) mod init {
    use cubecl::prelude::*;
    use cubecl_core::{self as cubecl, ir::StorageType};

    /// GPU kernel that writes a zero line into every slot of `output`.
    ///
    /// `_elem` is only used at launch time to pick the concrete numeric type
    /// `C` via `#[define(C)]`; it is never read inside the kernel body.
    #[cube(launch_unchecked, address_type = "dynamic")]
    pub fn zeros_array<C: Numeric>(output: &mut Array<Line<C>>, #[define(C)] _elem: StorageType) {
        // Guard: the launch may spawn more units than there are array slots.
        if ABSOLUTE_POS < output.len() {
            output[ABSOLUTE_POS] = Line::cast_from(C::from_int(0));
        }
    }
}