1use crate::compute::KernelTask;
2use crate::frontend::TensorHandleRef;
3use crate::ir::Elem;
4use crate::pod::CubeElement;
5use crate::{calculate_cube_count_elemwise, CubeDim, Kernel, Runtime};
6use cubecl_runtime::client::ComputeClient;
7use cubecl_runtime::server::{Binding, CubeCount, Handle};
8
9pub enum CubeCountSettings {
11 Input { pos: usize },
12 Output { pos: usize },
13 Custom(CubeCount),
14}
15
16pub struct Execution<'h, K, R: Runtime, Scalars> {
17 scalars: Scalars,
18 client: ComputeClient<R::Server, R::Channel>,
19 kernel: K,
20 inputs: &'h [TensorHandleRef<'h, R>],
21 outputs: &'h [TensorHandleRef<'h, R>],
22}
23
24impl<'h, K, R: Runtime> Execution<'h, K, R, ()> {
25 pub fn start(
26 kernel: K,
27 client: ComputeClient<R::Server, R::Channel>,
28 ) -> Execution<'h, K, R, ()> {
29 Execution {
30 scalars: (),
31 client,
32 kernel,
33 inputs: &[],
34 outputs: &[],
35 }
36 }
37
38 #[allow(unused)]
39 pub fn inputs(self, inputs: &'h [TensorHandleRef<'h, R>]) -> Execution<'h, K, R, ()> {
40 Execution {
41 scalars: self.scalars,
42 client: self.client,
43 kernel: self.kernel,
44 inputs,
45 outputs: self.outputs,
46 }
47 }
48
49 pub fn outputs(self, outputs: &'h [TensorHandleRef<'h, R>]) -> Execution<'h, K, R, ()> {
50 Execution {
51 scalars: self.scalars,
52 client: self.client,
53 kernel: self.kernel,
54 inputs: self.inputs,
55 outputs,
56 }
57 }
58}
59
60impl<'h, K, R> Execution<'h, K, R, ()>
61where
62 K: Kernel + 'static,
63 R: Runtime,
64{
65 pub fn with_scalars<E>(self, scalars: &[E]) -> Execution<'h, K, R, (&[E],)> {
66 Execution {
67 scalars: (scalars,),
68 client: self.client,
69 kernel: self.kernel,
70 inputs: self.inputs,
71 outputs: self.outputs,
72 }
73 }
74 #[allow(unused)]
76 pub fn execute(self, launch: CubeCountSettings) {
77 execute_dynamic::<R, K, f32, f32, f32>(
78 self.inputs,
79 self.outputs,
80 None,
81 None,
82 None,
83 self.kernel,
84 launch,
85 self.client,
86 )
87 }
88}
89
90impl<'h, 'a, K, R, E> Execution<'h, K, R, (&'a [E],)>
91where
92 K: Kernel + 'static,
93 R: Runtime,
94 E: CubeElement,
95{
96 pub fn with_scalars<'b, E2>(
97 self,
98 scalars: &'b [E2],
99 ) -> Execution<'h, K, R, (&'a [E], &'b [E2])> {
100 Execution {
101 scalars: (self.scalars.0, scalars),
102 client: self.client,
103 kernel: self.kernel,
104 inputs: self.inputs,
105 outputs: self.outputs,
106 }
107 }
108
109 #[allow(unused)]
111 pub fn execute(self, launch: CubeCountSettings) {
112 execute_dynamic::<R, K, E, f32, f32>(
113 self.inputs,
114 self.outputs,
115 Some(self.scalars.0),
116 None,
117 None,
118 self.kernel,
119 launch,
120 self.client,
121 )
122 }
123}
124
125impl<'h, 'a, 'b, K, R, E1, E2> Execution<'h, K, R, (&'a [E1], &'b [E2])>
126where
127 K: Kernel + 'static,
128 R: Runtime,
129 E1: CubeElement,
130 E2: CubeElement,
131{
132 #[allow(unused, clippy::type_complexity)]
133 pub fn with_scalars<'c, E3>(
134 self,
135 scalars: &'c [E3],
136 ) -> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])> {
137 Execution {
138 scalars: (self.scalars.0, self.scalars.1, scalars),
139 client: self.client,
140 kernel: self.kernel,
141 inputs: self.inputs,
142 outputs: self.outputs,
143 }
144 }
145 #[allow(clippy::too_many_arguments)]
147 pub fn execute(self, launch: CubeCountSettings)
148 where
149 K: Kernel + 'static,
150 R: Runtime,
151 {
152 execute_dynamic::<R, K, E1, E2, f32>(
153 self.inputs,
154 self.outputs,
155 Some(self.scalars.0),
156 Some(self.scalars.1),
157 None,
158 self.kernel,
159 launch,
160 self.client,
161 )
162 }
163}
164
165impl<K, R, E1, E2, E3> Execution<'_, K, R, (&[E1], &[E2], &[E3])>
166where
167 K: Kernel + 'static,
168 R: Runtime,
169 E1: CubeElement,
170 E2: CubeElement,
171 E3: CubeElement,
172{
173 #[allow(unused)]
175 pub fn execute(self, launch: CubeCountSettings) {
176 execute_dynamic::<R, K, E1, E2, E3>(
177 self.inputs,
178 self.outputs,
179 Some(self.scalars.0),
180 Some(self.scalars.1),
181 Some(self.scalars.2),
182 self.kernel,
183 launch,
184 self.client,
185 )
186 }
187}
188
189#[allow(clippy::too_many_arguments)]
190fn execute_dynamic<R, K, E1, E2, E3>(
191 inputs: &[TensorHandleRef<R>],
192 outputs: &[TensorHandleRef<R>],
193 scalars_1: Option<&[E1]>,
194 scalars_2: Option<&[E2]>,
195 scalars_3: Option<&[E3]>,
196 kernel: K,
197 launch: CubeCountSettings,
198 client: ComputeClient<R::Server, R::Channel>,
199) where
200 K: Kernel + 'static,
201 R: Runtime,
202 E1: CubeElement,
203 E2: CubeElement,
204 E3: CubeElement,
205{
206 let settings = execute_settings::<R, E1, E2, E3>(
207 inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
208 );
209
210 let mut handles = settings.handles_tensors;
211
212 handles.push(settings.handle_info.binding());
213 for handle in settings.handles_scalars.into_iter() {
214 handles.push(handle.binding());
215 }
216
217 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
218 client.execute(kernel, settings.cube_count, handles);
219}
220
221struct ExecuteSettings {
222 handles_tensors: Vec<Binding>,
223 handle_info: Handle,
224 handles_scalars: Vec<Handle>,
225 cube_count: CubeCount,
226}
227
228#[allow(clippy::too_many_arguments)]
229fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
230 inputs: &'a [TensorHandleRef<R>],
231 outputs: &'a [TensorHandleRef<R>],
232 scalars_1: Option<&[E1]>,
233 scalars_2: Option<&[E2]>,
234 scalars_3: Option<&[E3]>,
235 launch: CubeCountSettings,
236 client: &ComputeClient<R::Server, R::Channel>,
237) -> ExecuteSettings {
238 let mut info = Vec::new();
239 let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2);
240
241 let mut register_info_tensor = |strides: &[usize], shape: &[usize]| {
243 if info.is_empty() {
244 info.push(strides.len() as u32);
245 }
246
247 for s in strides.iter() {
248 info.push(*s as u32);
249 }
250 for s in shape.iter() {
251 info.push(*s as u32);
252 }
253 };
254
255 let mut num_elems_output = 0;
256
257 for (i, input) in inputs.iter().enumerate() {
259 if let CubeCountSettings::Input { pos } = &launch {
260 if i == *pos {
261 num_elems_output = calculate_num_elems_dyn_rank(input.shape);
262 }
263 };
264 register_info_tensor(input.strides, input.shape);
265 handles.push(input.handle.clone().binding());
266 }
267
268 for (i, output) in outputs.iter().enumerate() {
270 if let CubeCountSettings::Output { pos } = &launch {
271 if i == *pos {
272 num_elems_output = calculate_num_elems_dyn_rank(output.shape);
273 }
274 };
275 register_info_tensor(output.strides, output.shape);
276 handles.push(output.handle.clone().binding());
277 }
278
279 if R::require_array_lengths() {
281 for input in inputs.iter() {
282 let len = calculate_num_elems_dyn_rank(input.shape);
283 info.push(len as u32);
284 }
285
286 for output in outputs.iter() {
287 let len = calculate_num_elems_dyn_rank(output.shape);
288 info.push(len as u32);
289 }
290 }
291
292 let info = client.create(bytemuck::cast_slice(&info));
293
294 let handles_scalars =
296 create_scalar_handles::<R, E1, E2, E3>(scalars_1, scalars_2, scalars_3, client);
297
298 let cube_count = match launch {
299 CubeCountSettings::Custom(count) => count,
300 _ => calculate_cube_count_elemwise(num_elems_output, CubeDim::default()),
301 };
302
303 ExecuteSettings {
304 handles_tensors: handles,
305 handle_info: info,
306 handles_scalars,
307 cube_count,
308 }
309}
310
311fn create_scalar_handles<R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
312 scalars_0: Option<&[E1]>,
313 scalars_1: Option<&[E2]>,
314 scalars_2: Option<&[E3]>,
315 client: &ComputeClient<R::Server, R::Channel>,
316) -> Vec<Handle> {
317 let element_priority = |elem: Elem| match elem {
319 Elem::Float(_) | Elem::AtomicFloat(_) => 0,
320 Elem::Int(_) | Elem::AtomicInt(_) => 1,
321 Elem::UInt(_) | Elem::AtomicUInt(_) => 2,
322 Elem::Bool => panic!("Bool scalars are not supported"),
323 };
324 let scalar_priorities: [usize; 3] = [
325 element_priority(E1::cube_elem()),
326 element_priority(E2::cube_elem()),
327 element_priority(E3::cube_elem()),
328 ];
329
330 let mut handles_scalars = Vec::new();
331 for i in 0..3 {
332 for (j, scalar_priority) in scalar_priorities.iter().enumerate() {
333 if scalar_priority == &i {
334 if j == 0 {
335 if let Some(values) = &scalars_0 {
336 handles_scalars.push(client.create(bytemuck::cast_slice(values)));
337 }
338 } else if j == 1 {
339 if let Some(values) = &scalars_1 {
340 handles_scalars.push(client.create(bytemuck::cast_slice(values)));
341 }
342 } else if j == 2 {
343 if let Some(values) = &scalars_2 {
344 handles_scalars.push(client.create(bytemuck::cast_slice(values)));
345 }
346 }
347 }
348 }
349 }
350
351 handles_scalars
352}
353
/// Returns the number of elements described by `shape`, i.e. the product of all its dimensions.
///
/// An empty shape yields `1`.
pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
    shape.iter().product()
}
360}