1use core::marker::PhantomData;
2use cubecl_core::tensor_line_size_parallel;
3use cubecl_core::{Runtime, server};
4use cubecl_core::{calculate_cube_count_elemwise, server::Allocation};
5use cubecl_core::{prelude::*, server::CopyDescriptor};
6use cubecl_runtime::server::Handle;
7
/// Owned tensor: a server buffer handle together with the shape and strides
/// that describe its layout, typed by runtime `R` and element type `E`.
pub struct TensorHandle<R, E>
where
    R: Runtime,
    E: CubePrimitive,
{
    /// Buffer holding the tensor data.
    pub handle: server::Handle,
    /// Size of each dimension.
    pub shape: Vec<usize>,
    /// Stride of each dimension, counted in elements (not bytes); the element
    /// size is supplied separately when the tensor is borrowed or copied.
    pub strides: Vec<usize>,
    // Marker tying the element type to this handle without storing a value.
    elem: PhantomData<E>,
    // Marker tying the runtime to this handle without storing a value.
    runtime: PhantomData<R>,
}
21
22impl<R, E> core::fmt::Debug for TensorHandle<R, E>
23where
24 R: Runtime,
25 E: CubePrimitive,
26{
27 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28 f.write_fmt(format_args!(
29 "Tensor {{ shape: {:?}, strides: {:?}, dtype: {}}}",
30 self.shape,
31 self.strides,
32 core::any::type_name::<E>(),
33 ))
34 }
35}
36
37impl<R, E> Clone for TensorHandle<R, E>
38where
39 R: Runtime,
40 E: CubePrimitive,
41{
42 fn clone(&self) -> Self {
43 Self {
44 handle: self.handle.clone(),
45 shape: self.shape.clone(),
46 strides: self.strides.clone(),
47 elem: PhantomData,
48 runtime: PhantomData,
49 }
50 }
51}
52
53impl<R, E> TensorHandle<R, E>
54where
55 R: Runtime,
56 E: CubePrimitive,
57{
58 pub fn new(handle: server::Handle, shape: Vec<usize>, strides: Vec<usize>) -> Self {
60 Self {
61 handle,
62 shape,
63 strides,
64 elem: PhantomData,
65 runtime: PhantomData,
66 }
67 }
68
69 pub fn empty(client: &ComputeClient<R::Server>, shape: Vec<usize>) -> Self {
70 let elem_size = E::size().expect("To be a native type");
71 let Allocation { handle, strides } = client.empty_tensor(&shape, elem_size);
72
73 Self::new(handle, shape, strides)
74 }
75
76 pub fn from_ref(handle: &TensorHandleRef<'_, R>) -> Self {
78 Self {
79 handle: handle.handle.clone(),
80 shape: handle.shape.to_vec(),
81 strides: handle.strides.to_vec(),
82 elem: PhantomData,
83 runtime: PhantomData,
84 }
85 }
86
87 pub fn new_contiguous(shape: Vec<usize>, handle: Handle) -> Self {
89 let strides = Self::contiguous_strides(&shape);
90
91 Self {
92 handle,
93 shape,
94 strides,
95 elem: PhantomData,
96 runtime: PhantomData,
97 }
98 }
99
100 pub fn can_mut(&self) -> bool {
102 self.handle.can_mut()
103 }
104
105 pub fn as_ref(&self) -> TensorHandleRef<'_, R> {
106 unsafe {
107 TensorHandleRef::from_raw_parts(
108 &self.handle,
109 &self.strides,
110 &self.shape,
111 size_of::<E>(),
112 )
113 }
114 }
115
116 pub fn as_arg<'a>(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
118 let handle: TensorHandleRef<'a, R> = self.as_ref();
119
120 unsafe {
121 TensorArg::from_raw_parts::<E>(
122 handle.handle,
123 handle.strides,
124 handle.shape,
125 vectorisation,
126 )
127 }
128 }
129
130 pub fn as_copy_descriptor<'a>(&'a self) -> CopyDescriptor<'a> {
131 CopyDescriptor {
132 binding: self.handle.clone().binding(),
133 shape: &self.shape,
134 strides: &self.strides,
135 elem_size: size_of::<E>(),
136 }
137 }
138
139 fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
140 let mut strides = Vec::with_capacity(shape.len());
141
142 let mut current = 1;
143 shape.iter().enumerate().rev().for_each(|(_, val)| {
144 strides.push(current);
145 current *= val;
146 });
147 strides.reverse();
148 strides
149 }
150}
151impl<R, E> TensorHandle<R, E>
152where
153 R: Runtime,
154 E: Numeric,
155{
156 pub fn zeros(client: &ComputeClient<R::Server>, shape: Vec<usize>) -> Self {
157 let num_elements: usize = shape.iter().product();
158 let rank = shape.len();
159 let output = Self::empty(client, shape);
160
161 let line_size = tensor_line_size_parallel(
162 R::supported_line_sizes().iter().cloned(),
163 &output.shape,
164 &output.strides,
165 rank - 1,
166 );
167
168 let cube_dim = CubeDim::default();
169 let cube_count = calculate_cube_count_elemwise(num_elements / line_size as usize, cube_dim);
170 let array_len = output.handle.size() as usize / size_of::<E>();
171
172 unsafe {
173 init::zeros_array::launch_unchecked::<E, R>(
174 client,
175 cube_count,
176 cube_dim,
177 ArrayArg::from_raw_parts::<E>(&output.handle, array_len, line_size),
178 )
179 };
180
181 output
182 }
183}
184
/// Initialization kernels used by `TensorHandle::zeros`.
pub(crate) mod init {
    use cubecl::prelude::*;
    use cubecl_core as cubecl;

    /// Writes zero into every line of `output`.
    ///
    /// Each unit handles the line at `ABSOLUTE_POS`. Units whose position is
    /// not below `output.len()` do nothing, so the launch may safely contain
    /// more units than lines.
    #[cube(launch_unchecked)]
    pub fn zeros_array<C: Numeric>(output: &mut Array<Line<C>>) {
        // Skip units past the end of the array (e.g. from rounding the launch
        // up to whole cubes).
        if ABSOLUTE_POS < output.len() {
            output[ABSOLUTE_POS] = Line::cast_from(C::from_int(0));
        }
    }
}
195}