1use core::marker::PhantomData;
2use cubecl_core::{Runtime, server, zspace::strides};
3use cubecl_core::{calculate_cube_count_elemwise, server::MemoryLayout};
4use cubecl_core::{ir::StorageType, zspace::metadata::Metadata};
5use cubecl_core::{prelude::*, server::CopyDescriptor};
6use cubecl_core::{
7 tensor_vector_size_parallel,
8 zspace::{Shape, Strides},
9};
10use cubecl_runtime::server::Handle;
11
/// Handle to a tensor allocated on a [`Runtime`] server: pairs the raw memory
/// handle with the shape/stride metadata and the element storage type.
pub struct TensorHandle<R>
where
    R: Runtime,
{
    /// Handle to the underlying server-side memory allocation.
    pub handle: server::Handle,
    /// Shape and strides describing how the allocation is interpreted.
    /// Boxed to keep this struct small when moved around.
    pub metadata: Box<Metadata>,
    /// Storage type of a single element (determines element byte size).
    pub dtype: StorageType,
    /// Ties the handle to the runtime `R` without storing a runtime value.
    runtime: PhantomData<R>,
}
24
25impl<R> core::fmt::Debug for TensorHandle<R>
26where
27 R: Runtime,
28{
29 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
30 f.write_fmt(format_args!(
31 "Tensor {{ shape: {:?}, strides: {:?}, dtype: {}}}",
32 self.shape(),
33 self.strides(),
34 self.dtype,
35 ))
36 }
37}
38
39impl<R> Clone for TensorHandle<R>
40where
41 R: Runtime,
42{
43 fn clone(&self) -> Self {
44 Self {
45 handle: self.handle.clone(),
46 metadata: self.metadata.clone(),
47 dtype: self.dtype,
48 runtime: PhantomData,
49 }
50 }
51}
52
53impl<R> TensorHandle<R>
54where
55 R: Runtime,
56{
57 pub fn new(
59 handle: server::Handle,
60 shape: impl Into<Shape>,
61 strides: impl Into<Strides>,
62 storage: impl Into<Type>,
63 ) -> Self {
64 Self {
65 handle,
66 metadata: Box::new(Metadata::new(shape, strides)),
67 dtype: storage.into().storage_type(),
68 runtime: PhantomData,
69 }
70 }
71
72 pub fn empty(
73 client: &ComputeClient<R>,
74 shape: impl Into<Shape>,
75 storage: impl Into<Type>,
76 ) -> Self {
77 let storage = storage.into();
78 let shape: Shape = shape.into();
79 let elem_size = storage.storage_type().size();
80 let MemoryLayout {
81 memory: handle,
82 strides,
83 } = client.empty_tensor(shape.clone(), elem_size);
84
85 Self::new(handle, shape, strides, storage)
86 }
87
88 pub fn new_contiguous(shape: impl Into<Shape>, handle: Handle, storage: StorageType) -> Self {
90 let shape = shape.into();
91 let strides = Self::contiguous_strides(&shape);
92
93 Self {
94 handle,
95 metadata: Box::new(Metadata::new(shape, strides)),
96 dtype: storage,
97 runtime: PhantomData,
98 }
99 }
100
101 pub fn can_mut(&self) -> bool {
103 self.handle.can_mut()
104 }
105
106 pub fn binding(self) -> TensorBinding<R> {
107 unsafe {
108 TensorBinding::from_raw_parts(self.handle, self.metadata.strides, self.metadata.shape)
109 }
110 }
111
112 pub fn into_arg(self) -> TensorArg<R> {
114 self.binding().into_tensor_arg()
115 }
116
117 pub fn into_copy_descriptor(self) -> CopyDescriptor {
118 CopyDescriptor {
119 handle: self.handle.binding(),
120 shape: self.metadata.shape,
121 strides: self.metadata.strides,
122 elem_size: self.dtype.size(),
123 }
124 }
125
126 pub fn required_address_type(&self) -> AddressType {
127 let len = self.handle.size() / self.dtype.size() as u64;
128 AddressType::from_len(len as usize)
129 }
130
131 pub fn shape(&self) -> &Shape {
132 self.metadata.shape()
133 }
134
135 pub fn strides(&self) -> &Strides {
136 self.metadata.strides()
137 }
138
139 fn contiguous_strides(shape: &[usize]) -> Strides {
140 let mut strides = strides![1; shape.len()];
141
142 let mut current = 1;
143 shape.iter().rev().enumerate().for_each(|(i, val)| {
144 strides[i] = current;
145 current *= val;
146 });
147 strides.reverse();
148 strides
149 }
150}
151impl<R> TensorHandle<R>
152where
153 R: Runtime,
154{
155 pub fn zeros(
156 client: &ComputeClient<R>,
157 shape: impl Into<Shape>,
158 dtype: impl Into<Type>,
159 ) -> Self {
160 let dtype = dtype.into();
161 let shape = shape.into();
162 let num_elements: usize = shape.iter().product();
163 let rank = shape.len();
164 let output = Self::empty(client, shape, dtype);
165 let dtype = dtype.storage_type();
166
167 let vector_size = tensor_vector_size_parallel(
168 client.io_optimized_vector_sizes(dtype.size()),
169 output.shape(),
170 output.strides(),
171 rank - 1,
172 );
173
174 let working_units = num_elements / vector_size as usize;
175 let cube_dim = CubeDim::new(client, working_units);
176 let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
177 let array_len = output.handle.size_in_used() as usize / dtype.size();
178
179 unsafe {
180 init::zeros_array::launch_unchecked(
181 client,
182 cube_count,
183 cube_dim,
184 output.required_address_type(),
185 vector_size,
186 ArrayArg::from_raw_parts(output.handle.clone(), array_len),
187 dtype,
188 )
189 };
190
191 output
192 }
193}
194
/// Kernels used to initialize tensor memory.
pub(crate) mod init {
    use cubecl::prelude::*;
    use cubecl_core::{self as cubecl, ir::StorageType};

    /// Writes a zero vector to every in-bounds element of `output`.
    ///
    /// `_elem` only selects the concrete numeric type `C` at launch time via
    /// `#[define(C)]`; it carries no runtime data into the kernel body.
    #[cube(launch_unchecked, address_type = "dynamic")]
    pub fn zeros_array<C: Numeric, N: Size>(
        output: &mut Array<Vector<C, N>>,
        #[define(C)] _elem: StorageType,
    ) {
        // Guard: the launch may over-cover the array, so out-of-range
        // work units must not write.
        if ABSOLUTE_POS < output.len() {
            output[ABSOLUTE_POS] = Vector::cast_from(C::from_int(0));
        }
    }
}